mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.0rc1__cp310-cp310-win_amd64.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -0
- mindspore/_extends/parse/parser.py +22 -28
- mindspore/_extends/parse/standard_method.py +1 -15
- mindspore/_extends/pijit/pijit_func_white_list.py +5 -2
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/amp.py +18 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +12 -18
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +38 -102
- mindspore/common/_utils.py +1 -9
- mindspore/common/api.py +106 -155
- mindspore/common/{dynamic_shape/auto_dynamic_shape.py → auto_dynamic_shape.py} +23 -17
- mindspore/common/dtype.py +57 -98
- mindspore/common/dump.py +1 -1
- mindspore/common/file_system.py +9 -59
- mindspore/common/hook_handle.py +3 -22
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +20 -4
- mindspore/common/recompute.py +4 -2
- mindspore/common/tensor.py +52 -38
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/context.py +21 -15
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +1 -35
- mindspore/dataset/engine/datasets.py +315 -330
- mindspore/dataset/engine/datasets_user_defined.py +22 -38
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +5 -17
- mindspore/dataset/vision/utils.py +21 -632
- mindspore/device_context/ascend/op_tuning.py +1 -35
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -3
- mindspore/include/api/cell.h +4 -28
- mindspore/include/api/cfg.h +7 -24
- mindspore/include/api/context.h +0 -1
- mindspore/include/api/delegate.h +2 -0
- mindspore/include/api/dual_abi_helper.h +19 -100
- mindspore/include/api/graph.h +1 -14
- mindspore/include/api/kernel.h +3 -16
- mindspore/include/api/kernel_api.h +1 -9
- mindspore/include/api/metrics/accuracy.h +0 -9
- mindspore/include/api/model.h +1 -5
- mindspore/include/api/model_group.h +0 -4
- mindspore/include/api/model_parallel_runner.h +0 -2
- mindspore/include/api/status.h +10 -48
- mindspore/include/api/types.h +1 -6
- mindspore/include/dataset/constants.h +0 -9
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +2 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -5
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/distributed/__init__.py +0 -4
- mindspore/mint/distributed/distributed.py +14 -217
- mindspore/mint/nn/layer/_functions.py +2 -1
- mindspore/mint/nn/layer/conv.py +6 -6
- mindspore/mint/nn/layer/normalization.py +3 -3
- mindspore/nn/cell.py +174 -216
- mindspore/nn/layer/activation.py +2 -4
- mindspore/nn/layer/basic.py +13 -7
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/optim/adam.py +3 -1
- mindspore/nn/optim/lamb.py +3 -1
- mindspore/nn/optim/tft_wrapper.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +5 -39
- mindspore/nn/wrap/grad_reducer.py +15 -0
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_op_impl/cpu/__init__.py +0 -1
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +2 -12
- mindspore/ops/auto_generate/gen_extend_func.py +4 -4
- mindspore/ops/auto_generate/gen_ops_def.py +16 -290
- mindspore/ops/auto_generate/gen_ops_prim.py +76 -563
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/function/__init__.py +0 -1
- mindspore/ops/function/array_func.py +6 -10
- mindspore/ops/function/debug_func.py +2 -4
- mindspore/ops/function/grad/grad_func.py +12 -4
- mindspore/ops/function/math_func.py +32 -44
- mindspore/ops/function/nn_func.py +20 -18
- mindspore/ops/functional.py +1 -2
- mindspore/ops/functional_overload.py +12 -23
- mindspore/ops/operations/_inner_ops.py +12 -11
- mindspore/ops/operations/array_ops.py +50 -4
- mindspore/ops/operations/comm_ops.py +15 -1
- mindspore/ops/operations/custom_ops.py +4 -10
- mindspore/ops/operations/debug_ops.py +6 -6
- mindspore/ops/operations/manually_defined/ops_def.py +12 -12
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +1 -1
- mindspore/ops/primitive.py +10 -3
- mindspore/ops/tensor_method.py +7 -16
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +16 -0
- mindspore/parallel/_auto_parallel_context.py +15 -5
- mindspore/parallel/_parallel_serialization.py +2 -3
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_transformer/transformer.py +4 -4
- mindspore/parallel/_utils.py +11 -5
- mindspore/parallel/auto_parallel.py +9 -23
- mindspore/parallel/checkpoint_transform.py +0 -2
- mindspore/parallel/cluster/process_entity/_api.py +1 -4
- mindspore/parallel/cluster/run.py +3 -5
- mindspore/parallel/function/reshard_func.py +5 -6
- mindspore/parallel/nn/parallel_cell_wrapper.py +3 -40
- mindspore/parallel/nn/parallel_grad_reducer.py +8 -0
- mindspore/parallel/shard.py +21 -7
- mindspore/parallel/transform_safetensors.py +4 -10
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +9 -10
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +0 -9
- mindspore/profiler/common/profiler_context.py +2 -25
- mindspore/profiler/common/profiler_meta_data.py +0 -1
- mindspore/profiler/common/profiler_op_analyse.py +6 -10
- mindspore/{ops/_op_impl/cpu/joinedstr_op.py → profiler/common/validator/__init__.py} +1 -15
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +46 -91
- mindspore/profiler/envprofiler.py +5 -30
- mindspore/profiler/experimental_config.py +1 -16
- mindspore/profiler/platform/cpu_profiler.py +4 -10
- mindspore/profiler/platform/npu_profiler.py +1 -1
- mindspore/profiler/profiler.py +145 -193
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/runtime/__init__.py +4 -6
- mindspore/runtime/executor.py +0 -27
- mindspore/runtime/memory.py +0 -1
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +3 -3
- mindspore/train/amp.py +3 -0
- mindspore/train/callback/_callback.py +1 -2
- mindspore/train/callback/_checkpoint.py +8 -1
- mindspore/train/callback/_flops_collector.py +6 -10
- mindspore/train/callback/_train_fault_tolerance.py +7 -3
- mindspore/train/data_sink.py +4 -4
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +20 -4
- mindspore/train/serialization.py +15 -35
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +193 -192
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +0 -1109
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/dynamic_shape/enable_dynamic.py +0 -197
- /mindspore/common/{dynamic_shape/_auto_dynamic.py → _auto_dynamic.py} +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
@@ -39,11 +39,10 @@ from typing import (
 
 import weakref
 import mindspore as ms
-import mindspore.ops as ops
 from mindspore._checkparam import args_type_check, check_hook_fn
-from mindspore.common.
+from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
-from mindspore.common.hook_handle import HookHandle
+from mindspore.common.hook_handle import HookHandle
 from mindspore import context
 from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
 from mindspore import _checkparam as Validator
@@ -93,8 +92,9 @@ def register_cell_buffer_registration_hook(hook: Callable[..., None],):
 A handle that can be used to remove the added hook by calling
 `handle.remove()`.
 """
-
-
+from mindspore.utils.hooks import _RemovableHandle
+handle = _RemovableHandle(_global_buffer_registration_hooks)
+_global_buffer_registration_hooks[handle.id] = hook
 return handle
 
 
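The 2.7.0rc1 side of this hunk registers global buffer hooks through `_RemovableHandle` from the new `mindspore/utils/hooks.py` (added in the file list above). For orientation only, here is a generic sketch of the removable-handle pattern; it is not the actual `mindspore.utils.hooks` implementation, and the class and function names are invented for illustration:

```python
# Generic sketch of a removable-handle registry (illustrative only; not the
# real mindspore.utils.hooks._RemovableHandle).
import collections
import itertools

class RemovableHandleSketch:
    """Remembers the registry dict and its own key so remove() deletes exactly one entry."""
    _next_id = itertools.count()

    def __init__(self, hooks_dict):
        self.hooks_dict = hooks_dict
        self.id = next(RemovableHandleSketch._next_id)

    def remove(self):
        self.hooks_dict.pop(self.id, None)

_global_buffer_registration_hooks = collections.OrderedDict()

def register_buffer_registration_hook_sketch(hook):
    handle = RemovableHandleSketch(_global_buffer_registration_hooks)
    _global_buffer_registration_hooks[handle.id] = hook
    return handle  # caller keeps the handle and may later call handle.remove()
```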
@@ -155,8 +155,7 @@ class Cell(Cell_):
 IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
 '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
 '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
-'_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix',
-'requires_grad', 'cell_type', '_in_strategy', '_out_strategy']
+'_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix', 'requires_grad', 'cell_type']
 total_instance_count = 0
 _buffers: Dict[str, Optional[Tensor]]
 global_cells = weakref.WeakKeyDictionary()
@@ -207,7 +206,6 @@ class Cell(Cell_):
 super().__setattr__("mixed_precision_type", None)
 super().__setattr__("_lazy_construct_sig", None)
 super().__setattr__("_jit_graph_name", '')
-super().__setattr__("_compiled", False)
 init_pipeline()
 
 # call gc to release GE session resources used by non-used cell objects
@@ -241,8 +239,6 @@ class Cell(Cell_):
 super().__setattr__("_amp_level", "")
 super().__setattr__("_init_flag", False)
 super().__setattr__("_shard_fn", None)
-super().__setattr__("_in_strategy", None)
-super().__setattr__("_out_strategy", None)
 super().__setattr__("has_bprop", False)
 if hasattr(self, "bprop"):
 super().__setattr__("has_bprop", True)
@@ -430,13 +426,6 @@ class Cell(Cell_):
 """
 return self._bprop_debug
 
-@property
-def compiled(self):
-"""
-Get whether `Cell` is compiled in graph mode.
-"""
-return self._compiled
-
 @bprop_debug.setter
 def bprop_debug(self, value):
 """
@@ -557,23 +546,10 @@ class Cell(Cell_):
 
 @property
 def pipeline_segment(self):
-"""
-`pipeline_segment` represents the pipeline segment of current Cell.
-"""
 return self._pipeline_segment
 
 @pipeline_segment.setter
 def pipeline_segment(self, value):
-"""
-Set the `pipeline_segment` of a Cell. Only effective in zero_bubble_v scheduler.
-
-Args:
-value (int): The pipeline segment of a parameter.
-
-Raises:
-TypeError: If `value` is not int type or is a bool type.
-ValueError: If `value` is not a positive integer.
-"""
 if not isinstance(value, int) or isinstance(value, bool):
 raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
 "must be int type, but got type : {}".format(type(value)))
@@ -1051,13 +1027,12 @@ class Cell(Cell_):
 if self._forward_pre_hook:
 args, kwargs = self._run_forward_pre_hook(args, kwargs)
 
-if self._backward_hook:
-args = self._cell_backward_hook(args)
-
 if self._shard_fn is not None:
 output = self._shard_fn(*args, **kwargs)
 elif _pynative_executor.requires_grad():
-if self.
+if self._backward_hook:
+output = self._backward_hook_construct(*args, **kwargs)
+elif self._recompute_cell is not None:
 output = self._recompute_cell(*args, **kwargs)
 elif self.has_bprop:
 output = self._call_custom_bprop(*args, **kwargs)
@@ -1069,11 +1044,8 @@ class Cell(Cell_):
 if self._forward_hook:
 output = self._run_forward_hook(args, kwargs, output)
 
-if self.
-output = self.
-
-if self._backward_pre_hook:
-output = self._cell_backward_pre_hook(output)
+if self._backward_pre_hook and _pynative_executor.requires_grad():
+output = self._run_backward_pre_hook(output)
 
 return output
 
@@ -1108,6 +1080,23 @@ class Cell(Cell_):
 f"{default_args} default argument, total {positional_args + default_args}, "
 f"but got {len(args)}.")
 
+# pylint: disable=E0203
+def _hook_fn_registered(self):
+'''Hook function in graph mode'''
+# Check super().__init__() in graph mode.
+try:
+if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
+return True
+except AttributeError as e:
+raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
+f"Please use 'super().__init__()'.") from e
+if not self._is_recursion_hook:
+self._is_recursion_hook = True
+for cell in self.cells():
+if cell._hook_fn_registered():
+return True
+return False
+
 def _get_prims_recursively(self):
 all_prims = list()
 for _, value in self._primitives.items():
@@ -1133,6 +1122,9 @@ class Cell(Cell_):
 >>> net = nn.Dense(3, 4)
 >>> net.set_data_parallel()
 """
+if context._get_mode() == context.PYNATIVE_MODE:
+raise ValueError("set_data_parallel: does not support PyNative mode.")
+
 all_prims = self._get_prims_recursively()
 for prim in all_prims:
 prim.add_prim_attr("strategy_gen_mode", "data_parallel")
@@ -1211,6 +1203,8 @@ class Cell(Cell_):
 ... out = self.blocks[i](out)
 ... return out
 """
+if context._get_mode() == context.PYNATIVE_MODE:
+raise ValueError("The Cell offload does not support PyNative mode now.")
 if isinstance(backward_prefetch, str):
 Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
 else:
@@ -1218,10 +1212,11 @@ class Cell(Cell_):
 for prim in self._get_prims_recursively():
 prim._offload(backward_prefetch=backward_prefetch)
 
-def shard(self, in_strategy, out_strategy=None, parameter_plan=None):
+def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
 """
 Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
-generated by sharding propagation. In
+generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
+execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
 strategy for others will be set by sharding propagation.
 in_strategy and out_strategy define the input and output layout respectively.
 in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
@@ -1233,14 +1228,11 @@ class Cell(Cell_):
 In other parallel modes, strategies set here will be ignored.
 - If the input contain Parameter, its strategy should be set in `in_strategy`.
 
-.. warning::
-The method is currently not supported in PyNative mode.
-
 Args:
 in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
 defines the layout of the corresponding input.
 out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
-Default: ``None`` .
+It is not in use right now. Default: ``None`` .
 parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
 defines the layout of the parameter like "param_name: layout".
 The key is a parameter name of type 'str'.
@@ -1248,6 +1240,14 @@ class Cell(Cell_):
 If the parameter name is incorrect or the corresponding parameter
 has been set, the parameter setting will be ignored.
 Default: ``None`` .
+device (str): Select a certain device target. It is not in use right now.
+Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
+level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
+over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
+use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .
+
+Returns:
+Function, return the cell construct function that will be executed under auto parallel process.
 
 Examples:
 >>> import mindspore.nn as nn
@@ -1265,34 +1265,19 @@ class Cell(Cell_):
 ... def __init__(self):
 ... self.block1 = Block()
 ... self.block2 = Block()
-... self.block2.shard(in_strategy=((2, 1),),
+... self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
+... parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
 ... def construct(self, x):
 ... x = self.block1(x)
-... x = self.
+... x = self.block2_shard(x)
 ... return x
 """
 if ms.communication.management.get_group_size() == 1:
-return
-
+return self
 shard_fn = Shard()
-
-
-
-msg = (
-"For '%s', 'Shard' has been configured more than once. "
-"The existing in_strategy is %s and the existing out_strategy is %s. "
-"The new in_strategy %s and out_strategy %s may not take effect. "
-"It is recommended to configure 'Shard' only once."
-) % (
-self._cell_tag,
-self._in_strategy, # pylint: disable=E0203
-self._out_strategy, # pylint: disable=E0203
-shard_fn.in_strategy,
-shard_fn.out_strategy,
-)
-logger.warning(msg)
-self._in_strategy = shard_fn.in_strategy
-self._out_strategy = shard_fn.out_strategy
+fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
+self._shard_fn = fn
+return fn
 
 def _init_check(self):
 for param in self.get_parameters(expand=False):
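On the 2.7.0rc1 side of this hunk, `shard()` takes the extra `device` and `level` arguments and returns a callable that the cell is expected to invoke, as the updated docstring example shows. A minimal usage sketch of that pattern; the `Block` internals and the strategies are illustrative, not taken from the diff:

```python
import mindspore.nn as nn

class Block(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense1 = nn.Dense(8, 8)
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.dense1(x))

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.block1 = Block()
        self.block2 = Block()
        # shard() returns a wrapper function; keep it and call it in construct.
        self.block2_shard = self.block2.shard(
            in_strategy=((2, 1),),
            parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})

    def construct(self, x):
        x = self.block1(x)
        x = self.block2_shard(x)
        return x
```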
@@ -1301,13 +1286,9 @@ class Cell(Cell_):
 self._init_flag = True
 
 def _self_check(self):
-
-
-
-self._is_check_and_refresh = True
-except AttributeError as e:
-raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
-f"Please use 'super().__init__()'.") from e
+if not self._is_check_and_refresh:
+self.check_names_and_refresh_name()
+self._is_check_and_refresh = True
 
 def _predict(self, *args, **kwargs):
 '''Graph executor for predict'''
@@ -1328,7 +1309,6 @@ class Cell(Cell_):
 def __call__(self, *args, **kwargs):
 # Run in Graph mode.
 if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
-self._compiled = True
 if kwargs:
 bound_arguments = self._construct_sig.bind(*args, **kwargs)
 bound_arguments.apply_defaults()
@@ -1339,8 +1319,11 @@ class Cell(Cell_):
 if predict_compiled:
 return res
 self._check_construct_args(*args)
+
+if self._hook_fn_registered():
+logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
+f"function, please use context.set_context to set pynative mode.")
 self._self_check()
-self.__compile_cell_hook__ = True
 out = self.compile_and_run(*args, **kwargs)
 return out
 
@@ -1438,7 +1421,16 @@ class Cell(Cell_):
 exist_names.add(item.name)
 self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
 
-
+if context._get_mode() == context.PYNATIVE_MODE:
+if name in self.__dict__:
+del self.__dict__[name]
+params = self.__dict__.get('_params')
+if name in params:
+del params[name]
+params_list = self.__dict__.get('_params_list')
+params_list[name] = value
+else:
+object.__setattr__(self, name, value)
 
 def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
 """Set attr for parameter in list or tuple."""
@@ -1617,6 +1609,8 @@ class Cell(Cell_):
 _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
 else:
 self._check_construct_args(*inputs)
+# TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
+# which means that incremental mode is lacking dynamic input.
 else:
 self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)
 
@@ -2604,7 +2598,6 @@ class Cell(Cell_):
 raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
 Tensor._flatten_tensors(self.trainable_params(), fusion_size) # pylint: disable=W0212
 
-@jit_forbidden_register
 def register_forward_pre_hook(self, hook_fn, with_kwargs=False):
 """
 Register forward pre hook function for Cell object.
@@ -2624,6 +2617,7 @@ class Cell(Cell_):
 `with_kwargs` is ``True`` .
 
 Note:
+- The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
 - The `hook_fn` can modify the forward inputs by returning new inputs. If `with_kwargs` is ``Flase`` , a
 single value (whick will be wrapped into a tuple unless already a tuple) or a tuple of args should be
 returned. If `with_kwargs` is ``True`` , both `args` and `kwargs` should be returned.
@@ -2674,15 +2668,15 @@ class Cell(Cell_):
 (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
 value= [ 2.00000000e+00]))
 """
+if context._get_mode() == context.GRAPH_MODE:
+return HookHandle()
 check_hook_fn(hook_fn)
 handle = HookHandle(self._forward_pre_hook, extra_dict=self._forward_pre_hook_with_kwargs)
 self._forward_pre_hook[handle.handle_id] = hook_fn
 if with_kwargs:
 self._forward_pre_hook_with_kwargs[handle.handle_id] = True
-_update_hook_version()
 return handle
 
-@jit_forbidden_register
 def _run_forward_pre_hook(self, args, kwargs):
 """
 Running forward pre hook function registered on Cell object.
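Taken together with the graph-mode early return added above, forward hooks on the 2.7.0rc1 side only take effect in PyNative mode. A minimal sketch of the intended usage; the hook body and input shapes are illustrative:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)  # hooks only fire in PyNative mode

def forward_pre_hook(cell, args):
    # Returning None keeps the original inputs; returning a tuple replaces them.
    print("about to run", cell.cls_name, "with", len(args), "positional inputs")

net = nn.Dense(3, 4)
handle = net.register_forward_pre_hook(forward_pre_hook)
out = net(Tensor(np.ones((2, 3), np.float32)))
handle.remove()  # the returned HookHandle removes the registration
```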
@@ -2706,35 +2700,6 @@ class Cell(Cell_):
 args = ret
 return args, kwargs
 
-def _jit_forward_pre_hook(self, inputs):
-"""
-Compile forward pre hook function registered on Cell object.
-
-Args:
-inputs: The input objects of cell object.
-
-Returns:
-- **outputs** - New input objects or none.
-
-Supported Platforms:
-``Ascend`` ``GPU`` ``CPU``
-"""
-forward_pre_hook_inputs = inputs
-for fn in self._forward_pre_hook.values():
-ret = fn(self, forward_pre_hook_inputs)
-if ret is not None:
-if not isinstance(ret, tuple):
-forward_pre_hook_inputs = (ret,)
-else:
-forward_pre_hook_inputs = ret
-
-if len(forward_pre_hook_inputs) != len(inputs):
-raise TypeError(
-"The forward pre hook return value size is {} not equal to input size {}".format(
-len(forward_pre_hook_inputs), len(inputs)))
-return forward_pre_hook_inputs
-
-@jit_forbidden_register
 def register_forward_hook(self, hook_fn, with_kwargs=False):
 """
 Register forward hook function for Cell object.
@@ -2755,6 +2720,7 @@ class Cell(Cell_):
 - `output`: Output generated by the `construct` function.
 
 Note:
+- The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
 - The `hook_fn` can modify the forward outputs by returning new outputs.
 - In order to prevent running failed when switching to graph mode, it is not recommended to call it in the
 `construct` function of Cell object.
@@ -2807,44 +2773,15 @@ class Cell(Cell_):
 """
 if self.has_bprop:
 return HookHandle()
+if context._get_mode() == context.GRAPH_MODE:
+return HookHandle()
 check_hook_fn(hook_fn)
 handle = HookHandle(self._forward_hook, extra_dict=self._forward_hook_with_kwargs)
 self._forward_hook[handle.handle_id] = hook_fn
 if with_kwargs:
 self._forward_hook_with_kwargs[handle.handle_id] = True
-_update_hook_version()
 return handle
 
-def _jit_forward_hook(self, inputs, output):
-"""
-Compile forward hook function registered on Cell object.
-
-Args:
-inputs: The input objects of Cell object.
-output: The output object of Cell object.
-
-Returns:
-- **output** - New output object or none.
-
-Supported Platforms:
-``Ascend`` ``GPU`` ``CPU``
-"""
-forward_hook_output = output
-for fn in self._forward_hook.values():
-ret = fn(self, inputs, forward_hook_output)
-if ret is not None:
-forward_hook_output = ret
-
-if isinstance(output, tuple):
-if not isinstance(forward_hook_output, tuple):
-forward_hook_output = (forward_hook_output,)
-if len(forward_hook_output) != len(output):
-raise TypeError(
-"The forward hook return value size is {} not equal to output size {}".format(
-len(forward_hook_output), len(output)))
-return forward_hook_output
-
-@jit_forbidden_register
 def _run_forward_hook(self, args, kwargs, output):
 """
 Running forward hook function registered on Cell object.
@@ -2858,12 +2795,12 @@ class Cell(Cell_):
 output = ret
 return output
 
-@jit_forbidden_register
 def register_backward_pre_hook(self, hook_fn):
 """
 Register the backward pre hook function.
 
 Note:
+- The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
 - The 'hook_fn' must be defined as the following code.
 `cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
 - The 'hook_fn' should have the following signature:
@@ -2912,17 +2849,44 @@ class Cell(Cell_):
 >>> print(output)
 (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
 """
+if context._get_mode() == context.GRAPH_MODE:
+return HookHandle()
 check_hook_fn(hook_fn)
-handle = HookHandle(self._backward_pre_hook
+handle = HookHandle(self._backward_pre_hook)
 self._backward_pre_hook[handle.handle_id] = hook_fn
-if self._cell_backward_pre_hook is None:
+if self._cell_backward_pre_hook is None:
 # Generate a CellBackwardHook prim, and add function for it
 self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
 self, self._backward_pre_hook)
 self._cell_backward_pre_hook.register_backward_pre_hook()
-_update_hook_version()
 return handle
 
+def _run_backward_pre_hook(self, outputs):
+"""
+Running backward pre hook function registered on Cell object.
+
+Args:
+outputs: The output objects of cell object.
+
+Returns:
+- **outputs** - New backward gradient or None.
+
+Supported Platforms:
+``Ascend`` ``GPU`` ``CPU``
+"""
+if isinstance(outputs, tuple):
+ret = self._cell_backward_pre_hook(*outputs)
+else:
+ret = self._cell_backward_pre_hook(outputs)
+if isinstance(outputs, tuple):
+if len(outputs) == 1:
+ret = (ret,)
+if len(ret) != len(outputs):
+raise TypeError(
+"The backward pre hook return value size is {} not equal to output size {}".format(
+len(ret), len(outputs)))
+return ret
+
 def get_extra_state(self) -> Any:
 """Return any extra state to include in the cell's state_dict.
 
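The newly added `_run_backward_pre_hook` wires registered functions into the PyNative backward pass. A hedged usage sketch of `register_backward_pre_hook`; the network and the scaling factor are illustrative:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)  # no effect in graph mode, see the early return above

def backward_pre_hook(cell, grad_output):
    # grad_output is a tuple of gradients flowing into the cell's backward pass;
    # returning a new tuple replaces it, returning None keeps it unchanged.
    return tuple(g * 2 for g in grad_output)

net = nn.Dense(3, 1)
net.register_backward_pre_hook(backward_pre_hook)
grad_fn = ms.grad(net)
print(grad_fn(Tensor(np.ones((2, 3), np.float32))))  # gradients scaled by the hook
```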
@@ -2975,8 +2939,9 @@ class Cell(Cell_):
 A handle that can be used to remove the added hook by calling
 `handle.remove()`.
 """
-
-
+from mindspore.utils.hooks import _RemovableHandle
+handle = _RemovableHandle(self._state_dict_hooks)
+self._state_dict_hooks[handle.id] = hook
 return handle
 
 @jit_forbidden_register
@@ -3022,8 +2987,9 @@ class Cell(Cell_):
 >>> print("extra_param" in net_state_dict)
 True
 """
-
-
+from mindspore.utils.hooks import _RemovableHandle
+handle = _RemovableHandle(self._state_dict_pre_hooks)
+self._state_dict_pre_hooks[handle.id] = hook
 return handle
 
 def _save_to_state_dict(self, destination, prefix, keep_vars):
@@ -3169,8 +3135,9 @@ class Cell(Cell_):
 A handle that can be used to remove the added hook by calling
 `handle.remove()`.
 """
-
-
+from mindspore.utils.hooks import _RemovableHandle
+handle = _RemovableHandle(self._load_state_dict_pre_hooks)
+self._load_state_dict_pre_hooks[handle.id] = hook
 return handle
 
 @jit_forbidden_register
@@ -3202,8 +3169,9 @@ class Cell(Cell_):
 A handle that can be used to remove the added hook by calling
 `handle.remove()`.
 """
-
-
+from mindspore.utils.hooks import _RemovableHandle
+handle = _RemovableHandle(self._load_state_dict_post_hooks)
+self._load_state_dict_post_hooks[handle.id] = hook
 return handle
 
 def _load_from_state_dict(
@@ -3439,12 +3407,12 @@ class Cell(Cell_):
 )
 return _IncompatibleKeys(missing_keys, unexpected_keys)
 
-@jit_forbidden_register
 def register_backward_hook(self, hook_fn):
 """
 Register the backward hook function.
 
 Note:
+- The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
 - The 'hook_fn' must be defined as the following code.
 `cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
 the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
@@ -3496,17 +3464,65 @@ class Cell(Cell_):
 >>> print(output)
 (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
 """
+if context._get_mode() == context.GRAPH_MODE:
+return HookHandle()
 check_hook_fn(hook_fn)
-handle = HookHandle(self._backward_hook
+handle = HookHandle(self._backward_hook)
 self._backward_hook[handle.handle_id] = hook_fn
-if self._cell_backward_hook is None:
+if self._cell_backward_hook is None:
 # Generate a CellBackwardHook prim, and add function for it
 self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
 self, self._backward_hook)
 self._cell_backward_hook.register_backward_hook()
-_update_hook_version()
 return handle
 
+def _backward_hook_construct(self, *inputs, **kwargs):
+"""
+Backward hook construct method to replace original construct method.
+
+Args:
+inputs: The input objects of Cell object.
+kwargs (dict): Dictionary of variable keyword parameters.
+
+Returns:
+- **outputs** - The output objects of Cell object.
+
+Supported Platforms:
+``Ascend`` ``GPU`` ``CPU``
+"""
+# cell_backward_hook has CellBackwardHook op, so keep input args as they are.
+outputs = self._cell_backward_hook(*inputs)
+# If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
+# a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
+# Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
+is_need_unwrap = False
+if isinstance(outputs, tuple) and len(inputs) != 1:
+is_need_unwrap = True
+
+if self._recompute_cell is not None:
+if is_need_unwrap:
+outputs = self._recompute_cell(*outputs, **kwargs)
+else:
+outputs = self._recompute_cell(outputs, **kwargs)
+elif self.has_bprop:
+if is_need_unwrap:
+outputs = self._call_custom_bprop(*outputs, **kwargs)
+else:
+outputs = self._call_custom_bprop(outputs, **kwargs)
+else:
+if is_need_unwrap:
+outputs = self.construct(*outputs, **kwargs)
+else:
+outputs = self.construct(outputs, **kwargs)
+if isinstance(outputs, tuple):
+new_outputs = self._cell_backward_hook(*outputs)
+else:
+new_outputs = self._cell_backward_hook(outputs)
+# if outputs is (X,) and new_outpus is X
+if isinstance(outputs, tuple) and len(outputs) == 1:
+new_outputs = (new_outputs,)
+return new_outputs
+
 def set_param_ps(self, recurse=True, init_in_server=False):
 """
 Set whether the trainable parameters are updated by parameter server and whether the
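`_backward_hook_construct` wraps the cell's construct between two `CellBackwardHook` calls so that registered backward hooks see both `grad_input` and `grad_output`. A hedged sketch of the user-facing API; the network and values are illustrative:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)  # backward hooks are ignored in graph mode

def backward_hook(cell, grad_input, grad_output):
    # Inspect the gradients; return None to leave grad_input unchanged,
    # or return a replacement gradient tuple to override it.
    print("grad_input:", grad_input)
    print("grad_output:", grad_output)

net = nn.Dense(3, 1)
handle = net.register_backward_hook(backward_hook)
grad_fn = ms.grad(net)
_ = grad_fn(Tensor(np.ones((2, 3), np.float32)))
handle.remove()
```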
@@ -3585,7 +3601,7 @@ class Cell(Cell_):
 """
 Validator.check_bool(mode)
 Validator.check_bool(output_recompute)
-if not self._has_config_recompute:
+if not self._has_config_recompute:
 self._has_config_recompute = True
 else:
 logger.info("The recompute interface can be configured only once."
@@ -3628,7 +3644,7 @@ class Cell(Cell_):
 introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
 Default: ``False`` .
 """
-if context.
+if context.get_context("mode") == context.PYNATIVE_MODE:
 self._recompute_cell = recompute_registry.get()(self.construct)
 self._recompute()
 if 'mp_comm_recompute' in kwargs.keys():
@@ -3689,64 +3705,6 @@ class Cell(Cell_):
 """
 self._jit_graph_name = key
 
-def _jit_backward_pre_hook(self, grad_output):
-new_grad_output = grad_output
-if not isinstance(grad_output, tuple):
-new_grad_output = (grad_output,)
-
-for fn in self._backward_pre_hook.values():
-ret = fn(self, new_grad_output)
-if ret is not None:
-if not isinstance(ret, tuple):
-output = (ret,)
-else:
-output = ret
-else:
-output = ops.Depend()(new_grad_output, ret)
-new_grad_output = output
-
-if not isinstance(grad_output, tuple):
-if len(new_grad_output) == 1:
-return new_grad_output[0]
-raise TypeError(
-"The backward pre hook return value size is {} not equal to input size 1".format(
-len(new_grad_output)))
-
-if len(new_grad_output) != len(grad_output):
-raise TypeError(
-"The backward pre hook return value size is {} not equal to input size {}".format(
-len(new_grad_output), len(grad_output)))
-
-return new_grad_output
-
-def _jit_backward_hook(self, grad_input, grad_output):
-backward_hook_input = grad_input
-backward_hook_output = grad_output
-if not isinstance(grad_input, tuple):
-backward_hook_input = (grad_input,)
-if not isinstance(grad_output, tuple):
-backward_hook_output = (grad_output,)
-
-for fn in self._backward_hook.values():
-ret = fn(self, backward_hook_input, backward_hook_output)
-if ret is not None:
-if not isinstance(ret, tuple):
-output = (ret,)
-else:
-output = ret
-else:
-output = ops.Depend()(backward_hook_input, ret)
-
-backward_hook_input = output
-
-if not isinstance(grad_input, tuple):
-return backward_hook_input[0]
-
-if len(backward_hook_input) != len(grad_input):
-raise TypeError(
-"The backward hook return value size is {} not equal to input size {}".format(
-len(backward_hook_input), len(grad_input)))
-return backward_hook_input
 
 class GraphCell(Cell):
 """