mindspore-2.7.0-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/common/hook_handle.py
CHANGED
@@ -15,6 +15,7 @@
 """The removable handle for cell hook function."""
 from __future__ import absolute_import
 import weakref
+from collections import OrderedDict
 from mindspore._c_expression import TensorPy as Tensor_
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
 
@@ -173,3 +174,62 @@ class HookHandle:
         extra_dict = self.extra_dict_ref()
         if extra_dict is not None and self.handle_id in extra_dict:
             del extra_dict[self.handle_id]
+
+
+def _check_hook_results(pre_res, new_res, hook_fn):
+    if not isinstance(new_res, tuple):
+        raise RuntimeError(f"hook {hook_fn.__name__} should return a tuple of grad.")
+
+    new_res_len = len(new_res)
+    pre_res_len = len(pre_res)
+    if new_res_len != pre_res_len:
+        raise RuntimeError(
+            f"hook {hook_fn.__name__} returned incorrect length {new_res_len}, expected {pre_res_len}."
+        )
+
+
+class _HookUtils:
+    r"""
+    Internal utility class for hook registration and execution.
+    """
+
+    @staticmethod
+    def register_hook(hook_dict, hook_fn):
+        """
+        Register hook
+
+        Args:
+            hook_dict (dict): hook dict.
+            hook_fn (function): hook function.
+
+        Returns:
+            tuple: Updated hook_dict and HookHandle object.
+        """
+        if hook_dict is None:
+            hook_dict = OrderedDict()
+        handle = HookHandle(hook_dict)
+        hook_dict[handle.handle_id] = hook_fn
+        return hook_dict, handle
+
+    @staticmethod
+    def run_hook(hook_dict, args):
+        """
+        Run all hooks in the hook_dict with the given arguments.
+
+        Args:
+            hook_dict (dict): Dictionary of registered hooks.
+            args (tuple): Arguments to pass to the hook functions.
+
+        Returns:
+            Modified first argument if any hook returns a new value; otherwise, None.
+        """
+        is_modify = False
+        args_list = list(args)
+        # Note: We create a list from hook_dict.values() to ensure safe iteration.
+        for hook_fn in list(hook_dict.values()):
+            res = hook_fn(*args_list)
+            if res is not None:
+                _check_hook_results(args_list[0], res, hook_fn)
+                args_list[0] = res
+                is_modify = True
+        return args_list[0] if is_modify else None
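The new _HookUtils helpers formalize a simple pattern: hooks live in an OrderedDict keyed by handle id, and a hook that returns a tuple of the same length replaces the first argument for the hooks that follow. A minimal self-contained sketch of that dispatch pattern (plain Python, not the MindSpore internals above; make_handle_id stands in for HookHandle.handle_id):

from collections import OrderedDict
from itertools import count

def make_handle_id(_counter=count()):
    # Stand-in for the unique handle_id produced by HookHandle above.
    return next(_counter)

def register_hook(hook_dict, hook_fn):
    # Mirrors _HookUtils.register_hook: lazily create the dict, key by handle id.
    hook_dict = hook_dict if hook_dict is not None else OrderedDict()
    handle_id = make_handle_id()
    hook_dict[handle_id] = hook_fn
    return hook_dict, handle_id

def run_hook(hook_dict, args):
    # Mirrors _HookUtils.run_hook: a hook returning a tuple replaces the first argument.
    args_list = list(args)
    modified = False
    for hook_fn in list(hook_dict.values()):
        res = hook_fn(*args_list)
        if res is not None:
            assert isinstance(res, tuple) and len(res) == len(args_list[0])
            args_list[0] = res
            modified = True
    return args_list[0] if modified else None

hooks, _ = register_hook(None, lambda grads: tuple(g * 2 for g in grads))
print(run_hook(hooks, ((1.0, 2.0),)))  # (2.0, 4.0)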
mindspore/common/jit_config.py
CHANGED
@@ -27,7 +27,11 @@ class JitConfig:
             adopt KernelByKernel execution mode.
             - ``"O1"``: Using commonly used optimizations and automatic operator fusion optimizations,
               adopt KernelByKernel execution mode.
-            - ``"O2"``:
+            - ``"O2"``: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
+              for Ascend model compilation and execution. Note: O2 only supports GRAPH Mode in Ascend,
+              only supports whole graph sinking or sub graph sinking in pipeline parallel, and does not support
+              dynamic shape scenes. In addition, this mode incurs additional compilation costs and is difficult to
+              debug and tune.
 
         exc_mode (str, optional): Control the execution mode of the model.
             Supports ["auto", "sink", "no_sink"]. Default: ``"auto"`` .
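For reference, a minimal sketch of building the config described above; jit_config_dict is the plain dict that jit_trace.py (next diff) now threads through compilation. Whether the O2 path is usable depends on an Ascend/CANN environment, so treat this as illustrative only:

from mindspore import JitConfig

# Select the CANN GraphEngine path described in the docstring above.
cfg = JitConfig(jit_level="O2", exc_mode="auto")
print(cfg.jit_config_dict)  # plain dict handed to the compile/run entry points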
mindspore/common/jit_trace.py
CHANGED
@@ -28,6 +28,7 @@ from mindspore._c_expression import TraceRecorder as tr
 from mindspore._c_expression import JitExecutor_
 from mindspore._c_expression import TensorPy as Tensor, CSRTensor, COOTensor
 from mindspore._c_expression import typing
+from mindspore.common.jit_config import JitConfig
 
 
 class TraceJitContext(JitContext):
@@ -123,19 +124,19 @@ def nested_run(obj, cell, *args):
     return file_names, linenos, res
 
 
-def _jit_trace():
+def _jit_trace(jit_config):
     """Return the wrapped function for trace mode jit."""
     def wrap_func(fn):
         if hasattr(fn, "construct"):
             if isinstance(fn, ms.nn.Cell):
                 # Bound the cell object to get the self arg.
-                return types.MethodType(_jit_trace()(fn.construct.__func__), fn)
+                return types.MethodType(_jit_trace(jit_config)(fn.construct.__func__), fn)
             if isinstance(fn, type) and issubclass(fn, ms.nn.Cell):
-                fn.construct = _jit_trace()(fn.construct)
+                fn.construct = _jit_trace(jit_config)(fn.construct)
                 return fn
 
         if isinstance(fn, types.MethodType):
-            return types.MethodType(_jit_trace()(fn.__func__), fn.__self__)
+            return types.MethodType(_jit_trace(jit_config)(fn.__func__), fn.__self__)
 
         if not isinstance(fn, types.FunctionType):
             logger.warning(f"The fn should be function, method or cell instance/class, but got {fn}")
@@ -150,6 +151,10 @@ def _jit_trace():
             if jit_context():
                 return fn(*args, **kwargs)
             # Start trace process.
+            if jit_config:
+                jit_config_dict = jit_config.jit_config_dict
+            else:
+                jit_config_dict = JitConfig().jit_config_dict
             if kwargs:
                 bound_arguments = inspect.signature(fn).bind(*args, **kwargs)
                 bound_arguments.apply_defaults()
@@ -170,14 +175,16 @@ def _jit_trace():
             line_str = fn.__code__.co_filename + ":" + str(fn.__code__.co_firstlineno)
             generate_name = generate_name + '#[' + line_str + ']'
 
-            new_compile = _jit_trace_begin(
+            new_compile = _jit_trace_begin(
+                generate_name, *jit_args, jit_config=jit_config_dict)
             if new_compile:
                 fn_res = fn(*args, **kwargs)
                 logger.debug(f'fn: {fn}, fn_res: {fn_res}, line: {line_str}')
                 # Use fn's output to build func graph's output.
-                output = _jit_trace_end(fn_res)
+                output = _jit_trace_end(fn_res, jit_config=jit_config_dict)
             else:
-
+                # Run with compilation.
+                output = _jit_trace_end(None, jit_config=jit_config_dict)
             logger.debug(f'output: {output}')
             return output
 
@@ -224,7 +231,7 @@ def _get_args_for_run(args):
     return tuple(new_args)
 
 
-def _jit_trace_begin(fn_name, *args):
+def _jit_trace_begin(fn_name, *args, **kwargs):
     """
     Start to build a MindIR func graph for a code snippet by trace method.
 
@@ -257,6 +264,10 @@ def _jit_trace_begin(fn_name, *args):
         ...
         >>> out = tensor_add(x, y)
     """
+    if "jit_config" in kwargs:
+        jit_config = kwargs["jit_config"]
+    else:
+        jit_config = JitConfig().jit_config_dict
     global _using_trace
     if _using_trace:
         raise RuntimeError(
@@ -279,7 +290,7 @@ def _jit_trace_begin(fn_name, *args):
     if not _compile_only and phase in _trace_compile_cache:
         logger.debug('Had compiled, just run.')
         _trace_jit_context.compiled = True
-        output = tr.get_instance().run_graph(phase, args)
+        output = tr.get_instance().run_graph(phase, jit_config, args)
         from mindspore.common.api import _convert_python_data
         _trace_jit_context.result = _convert_python_data(output)
         logger.debug(f'jit trace result: {_trace_jit_context.result}')
@@ -295,7 +306,7 @@ def _jit_trace_begin(fn_name, *args):
     return True
 
 
-def _jit_trace_end(*output_args):
+def _jit_trace_end(*output_args, **kwargs):
     """
     Finish building a MindIR func graph for a code snippet by trace method.
 
@@ -330,19 +341,23 @@ def _jit_trace_end(*output_args):
         ...
         >>> out = tensor_add(x, y)
     """
+    if "jit_config" in kwargs:
+        jit_config = kwargs["jit_config"]
+    else:
+        jit_config = JitConfig().jit_config_dict
     if _trace_jit_context.compiled:
         output = _trace_jit_context.result
         logger.debug(f'jit trace result: {output}')
     else:
         logger.debug(f'output_args: {output_args}')
         file_names, linenos = _get_caller_lines()
-        tr.get_instance().end_graph(file_names, linenos, *output_args)
+        tr.get_instance().end_graph(file_names, linenos, jit_config, *output_args)
         if _compile_only:
            output = output_args[0] if len(output_args) == 1 else output_args
         else:
            args = _get_args_for_run(_trace_jit_context.args)
            output = tr.get_instance().run_graph(
-                _trace_jit_context.phase, args)
+                _trace_jit_context.phase, jit_config, args)
            from mindspore.common.api import _convert_python_data
            output = _convert_python_data(output)
            logger.debug(f'jit trace result: {output}')
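The change above is plain plumbing: an optional jit_config keyword that falls back to a default JitConfig() when the caller passes nothing. A hedged, generic sketch of that defaulting idiom, independent of the private _jit_trace_* entry points (the helper name is hypothetical):

from mindspore import JitConfig

def _resolve_jit_config(jit_config=None):
    # Fall back to default compile options when no config is supplied,
    # mirroring how _jit_trace_begin/_jit_trace_end default above.
    return (jit_config or JitConfig()).jit_config_dict

print(_resolve_jit_config())                            # default options
print(_resolve_jit_config(JitConfig(jit_level="O1")))   # caller-provided options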
mindspore/common/lazy_inline.py
CHANGED
@@ -32,9 +32,11 @@ def lazy_inline(fn=None, attrs=None, policy=None):
         static_graph_expert_programming.html#using-lazy-inline-decorator>`_ .
 
     .. warning::
-        This feature is only supported on Ascend and is not supported on other hardwares.
-        The construct parameters must be positional or key word arguments and have not default values.
-        The cell has not switch sub graph.
+        - This feature is only supported on Ascend and is not supported on other hardwares.
+        - The construct parameters must be positional or key word arguments and have not default values.
+        - The cell has not switch sub graph.
+        - In the gradient accumulation scenario, it is recommended to use the @lazy_inline decorator to
+          reduce compilation time, and this decorator is only allowed to configure on the outermost cell.
 
     Args:
         fn (function): `__init__` function of a cell.
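As a reminder of how the decorator is applied (a minimal sketch; the Block/Net names and sizes are illustrative, and per the warning above the feature only takes effect on Ascend under graph compilation):

import mindspore.nn as nn
from mindspore import lazy_inline

class Block(nn.Cell):
    @lazy_inline                      # decorate __init__ so each Block keeps a reusable sub-graph
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(16, 16)

    def construct(self, x):
        return self.dense(x)

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.blocks = nn.SequentialCell([Block() for _ in range(4)])

    def construct(self, x):
        return self.blocks(x)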
mindspore/common/parameter.py
CHANGED
@@ -21,7 +21,6 @@ from copy import copy
 import time
 import os
 import sys
-import math
 import numbers
 import numpy as np
 
@@ -29,8 +28,6 @@ from mindspore import log as logger
 from mindspore.log import _LogActionOnce
 from mindspore._c_expression import ParamInfo
 from mindspore.common import dtype as mstype
-from mindspore import context
-from mindspore.common._utils import get_slice_num, get_slice_shape
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor, _TensorMeta
 from mindspore.common.hook_handle import _update_hook_version
@@ -39,10 +36,6 @@ from mindspore._check_jit_forbidden_api import jit_forbidden_register
 from mindspore._c_expression import TensorPy as Tensor_
 from mindspore.parallel._tensor import _get_slice_index
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
-from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table, \
-    _is_ps_mode
-from mindspore.parallel._ps_context import _reinsert_hash_table_size, _insert_accumu_init_info, _cache_enable
-from mindspore.common._decorator import deprecated
 from mindspore.communication._comm_helper import _is_initialized
 from mindspore.communication import get_group_size, get_rank
 import mindspore.common._monad as monad
@@ -138,11 +131,7 @@ def _offload_if_config(data):
     Args:
         data: The parameter data to offload.
     """
-    if
-        return
-
-    offload_context = context.get_offload_context()
-    if offload_context.get("offload_param", None) != "disk":
+    if data is None:
         return
 
     data_size_threshold = 512
@@ -219,7 +208,10 @@ class Parameter(Tensor_):
                 self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
                 self.param_tuple = (self.param_a, self.param_a)
 
-        requires_grad (bool):
+        requires_grad (bool): It is Used to filter parameters in :func:`mindspore.nn.Cell.trainable_params()`.
+            If it is ``False``, the filter parameters will not be returned in
+            :func:`mindspore.nn.Cell.trainable_params()`.
+            Default: ``True`` .
         layerwise_parallel (bool): When `layerwise_parallel` is true in data/hybrid parallel mode,
             broadcast and gradients communication would not be applied to the `Parameter`. Default: ``False`` .
         parallel_optimizer (bool): It is used to filter the weight shard operation in parallel mode. It works only when
@@ -230,10 +222,8 @@ class Parameter(Tensor_):
         device(str): Only Ascend device target is supported. It is used to specify the device which the parameter is
             stored. By default, the parameter will be stored on NPU while computing. When the device is specified as
             ``"CPU"``, the parameter will be loaded into the device when it needs to be used, and unloaded to the CPU
-            after use. It takes effext only when `
-
-            Less device memory is needed when device is
-            specified as ``"CPU"``.
+            after use. It takes effext only when `jit_level` is not ``"O2"`` and `memory_optimize_level` is ``O0``
+            in :func:`mindspore.set_context`. Less device memory is needed when device is specified as ``"CPU"``.
 
     Examples:
         >>> import numpy as np
@@ -272,8 +262,6 @@ class Parameter(Tensor_):
         obj.is_default_input_init = init_data_flag
         if obj.has_init:
             obj.init_mode = default_input
-        else:
-            _offload_if_config(obj)
         return obj
 
     def __reduce_ex__(self, _):
@@ -289,7 +277,6 @@ class Parameter(Tensor_):
     def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True,
                  storage_format="", device=None):
         self.param_info = ParamInfo()
-        self.init_in_server = False
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
@@ -300,32 +287,15 @@ class Parameter(Tensor_):
         self.is_init = False
         self._inited_param = None
         self._sliced = False
-        self.is_param_ps = False
-        self.push_weight_to_server = False
-        self.pull_weight_from_server = False
         self.requires_aggr = True
         self._cast_type = None
         self._unique = False
         self.is_in_parallel = _is_in_auto_parallel_mode()
         self._pipeline_stage_list = []
-        self.slice_num = 1
         if -1 in self.shape:
             raise ValueError(f"All shape elements of the Parameter must be positive. But got None.")
         if isinstance(default_input, (Tensor_, Tensor)):
-
-            # And save out range data to persistent storage to support TB-Level size parameter.
-            slice_num_of_persistent_data = get_slice_num(default_input.dtype, default_input.shape)
-            if slice_num_of_persistent_data > 1:
-                data_shape = list(default_input.shape)
-                slice_first_dim = math.ceil(data_shape[0] / slice_num_of_persistent_data)
-                data_shape[0] = slice_first_dim
-                self.param_info.use_persistent_storage = True
-                self.param_info.origin_shape = default_input.shape
-                self.slice_num = slice_num_of_persistent_data
-                Tensor_.__init__(self, dtype=default_input.dtype, shape=tuple(data_shape))
-            else:
-                Tensor_.__init__(self, dtype=default_input.dtype, shape=default_input.shape)
-
+            Tensor_.__init__(self, dtype=default_input.dtype, shape=default_input.shape)
         elif isinstance(default_input, int):
             Tensor_.__init__(self, dtype=mstype.int64, shape=())
         elif isinstance(default_input, float):
@@ -387,11 +357,10 @@ class Parameter(Tensor_):
                     return (Tensor, data.asnumpy(), mstype.qint4x2)
                 return (Tensor, data.asnumpy())
 
-            not_init_data = not init_param or
-                or _is_in_auto_parallel_mode() or _is_parallel_mode()
+            not_init_data = not init_param or _is_in_auto_parallel_mode() or _is_parallel_mode()
             if not_init_data:
                 # do not init data while in auto parallel.
-                return (Tensor, None, data.dtype,
+                return (Tensor, None, data.dtype, data.shape, data.init)
             return (Tensor, data.init_data())
         if isinstance(data, int):
             return (Tensor, data, mstype.int32)
@@ -399,29 +368,6 @@ class Parameter(Tensor_):
             return (Tensor, data, mstype.float32)
         return (Tensor, data)
 
-    def set_param_ps(self, init_in_server=False):
-        """
-        Set whether the trainable parameter is updated by parameter server and whether the
-        trainable parameter is initialized on server.
-
-        Note:
-            It only works when a running task is in the parameter server mode.
-            It is supported only in graph mode.
-
-        Args:
-            init_in_server (bool): Whether trainable parameter updated by parameter server is
-                initialized on server. Default: ``False``.
-
-        """
-        if not _is_ps_mode() or not (_is_role_worker() or _is_role_pserver() or _is_role_sched()):
-            raise RuntimeError("Must complete following two steps before calling set_param_ps: \n"
-                               "1. context.set_ps_context(enable_ps=True) \n"
-                               "2. export MS_ROLE environment variable \n"
-                               "Please refer to the official website for detailed usage.")
-        self.is_param_ps = True
-        self.init_in_server = init_in_server
-        self.param_info.init_in_server = init_in_server
-
     def copy(self):
         """
         Copy the parameter.
@@ -437,16 +383,6 @@ class Parameter(Tensor_):
         """
         return self.clone(init='same')
 
-    @deprecated("1.8", "set_param_fl")
-    def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
-        if push_to_server:
-            self.push_weight_to_server = True
-        if pull_from_server:
-            self.pull_weight_from_server = True
-        if not requires_aggr:
-            self.requires_aggr = False
-            self.param_info.requires_aggr = False
-
     @property
     def inited_param(self):
         """
@@ -512,8 +448,6 @@ class Parameter(Tensor_):
             raise ValueError("The type of the Parameter's name should be 'string' or 'None', "
                              "but got {}.".format(type(name_)))
 
-        if _is_role_worker() and self.cache_enable:
-            _reinsert_hash_table_size(name_, self.param_info.name)
         self.param_info.name = name_
 
     @property
@@ -642,8 +576,6 @@ class Parameter(Tensor_):
         x.param_info = param_info_clone
         x.is_init = False
         x.init = self.init
-        x.is_param_ps = self.is_param_ps
-        x.init_in_server = self.init_in_server
         x.cache_enable = self.cache_enable
         if x.cache_enable:
             x.key = _get_unique_parameter_key()
@@ -651,7 +583,7 @@ class Parameter(Tensor_):
         if self.cache_shape:
             x.cache_shape = self.cache_shape
         if init != 'same':
-            shape = self.shape
+            shape = self.shape
             dtype = self.dtype
             tensor = initializer(init, shape=shape, dtype=dtype)
             x.set_data(tensor)
@@ -796,6 +728,7 @@ class Parameter(Tensor_):
             raise TypeError("The argument `requires_grad` must be bool type")
         Tensor_.wait_pipeline(self)
         self.param_info.requires_grad = value
+        self._requires_grad = value
 
     @property
     def data(self):
@@ -862,20 +795,6 @@ class Parameter(Tensor_):
             raise TypeError("The original tensor data is initialized, but the argument 'data' is not initialized."
                             "Please initialize 'data' before call this method.")
 
-    @staticmethod
-    def _from_tensor(tensor, *args, **kwargs):
-        """Create a `Parameter` that data is shared from a `Tensor`."""
-        if not isinstance(tensor, Tensor_):
-            raise TypeError(f"The type of input must be Tensor, but got {type(tensor)}.")
-        param = Tensor_.__new__(Parameter)
-        Tensor_.__init__(param, tensor)
-        param.init = None
-        param.init_mode = None
-        param.has_init = False
-        param.is_default_input_init = False
-        Parameter.__init__(param, tensor, *args, **kwargs)
-        return param
-
     @jit_forbidden_register
     def set_data(self, data, slice_shape=False):
         """
@@ -981,16 +900,7 @@ class Parameter(Tensor_):
 
         init_data_args = self._get_init_data_args(layout)
 
-
-            return self
-        if self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Tensor) and \
-                self.init_mode.init is not None and _is_role_worker():
-            if self.cache_enable:
-                data = self.init_mode.init_data(*init_data_args)
-            else:
-                data = self.init_mode.init_data(0, [1])
-        else:
-            data = self.init_mode.init_data(*init_data_args)
+        data = self.init_mode.init_data(*init_data_args)
         origin_dtype = self.dtype
         obj = self._update_tensor_data(data)
         if self.dtype != origin_dtype:
@@ -999,7 +909,6 @@ class Parameter(Tensor_):
         self._inited_param = obj
         obj.init_mode = None
         obj.sliced = set_sliced
-        _offload_if_config(obj)
         return obj
 
     def register_hook(self, hook_fn):
@@ -1154,9 +1063,6 @@ class ParameterTuple(tuple):
             if not x1.cache_enable:
                 continue
 
-            if _is_role_worker():
-                _clone_hash_table(x.name, x.key, x1.name, x1.key)
-                _insert_accumu_init_info(x1.name, init_to_value(init))
         return ParameterTuple(new)
 
     def __parameter_tuple__(self):
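The reworded requires_grad documentation above describes filtering through Cell.trainable_params(); a small sketch of that behaviour (the two-parameter Net is illustrative):

import mindspore as ms
from mindspore import Parameter, Tensor, nn

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = Parameter(Tensor([1.0], ms.float32), name="w")                       # requires_grad=True by default
        self.b = Parameter(Tensor([0.0], ms.float32), name="b", requires_grad=False)  # filtered out below

    def construct(self, x):
        return x * self.w + self.b

net = Net()
print([p.name for p in net.trainable_params()])  # only the 'w' parameter is returned
print([p.name for p in net.get_parameters()])    # both parameters are still registered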
mindspore/common/recompute.py
CHANGED
@@ -22,11 +22,10 @@ from mindspore.common.tensor import Tensor
 from mindspore import ops
 from mindspore.ops.composite import GradOperation
 from mindspore.common._register_for_recompute import recompute_registry
-from mindspore.common.api import _pynative_executor, _no_grad
+from mindspore.common.api import _pynative_executor, _no_grad, _run_in_jit
 from mindspore.common.generator import get_rng_state, set_rng_state
 from mindspore.train.amp import AmpDecorator
 from mindspore._c_expression.amp import get_curr_amp_strategy
-from mindspore._check_jit_forbidden_api import jit_forbidden_register
 
 
 class _WrapCell(Cell):
@@ -211,22 +210,15 @@ def _detach_input(input_arg):
 def _check_validation(block):
     if not isinstance(block, Cell):
         raise TypeError("Recompute function now only support block which inherited from Cell!")
-    if block.construct.__code__.co_name == "staging_specialize":
-        logger.warning('Block\'s construct method decorated by @jit that recompute '
-                       'function will not come into effect.')
 
 
-@jit_forbidden_register
 def recompute(block, *args, **kwargs):
     r"""
     This function is used to reduce memory, when run block, rather than
     storing the intermediate activation computed in forward pass, we will recompute it in backward pass.
 
     Note:
-
-        - This function interface now only support pynative mode. you can use Cell.recompute interface
-          in graph mode.
-        - When use recompute function, block object should not decorated by @jit.
+        Recompute function only support block which inherited from Cell object.
 
     Args:
         block (Cell): Block to be recompute.
@@ -238,7 +230,6 @@ def recompute(block, *args, **kwargs):
 
     Raises:
         TypeError: If `block` is not Cell object.
-        AssertionError: If execute mode is not PYNATIVE_MODE.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -272,6 +263,8 @@ def recompute(block, *args, **kwargs):
     """
 
     _check_validation(block)
+    if _run_in_jit():  # @jit.cond: True
+        return ops.recompute_block(block)(*args, **kwargs)
     return _RecomputeCell(block)(*args, **kwargs)
 
 
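With the jit branch added above, mindspore.recompute is no longer restricted to PyNative-only use. A minimal usage sketch (the toy Block/Net and shapes are illustrative):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

class Block(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 4)

    def construct(self, x):
        return self.dense(x)

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def construct(self, x):
        # Activations of `block` are recomputed in the backward pass instead of being stored.
        return ms.recompute(self.block, x)

net = Net()
grad = ms.grad(net)(Tensor(np.ones((2, 4), np.float32)))  # backward pass triggers recomputation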