mindspore 2.7.0rc1__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
|
@@ -39,10 +39,11 @@ from typing import (
|
|
|
39
39
|
|
|
40
40
|
import weakref
|
|
41
41
|
import mindspore as ms
|
|
42
|
+
import mindspore.ops as ops
|
|
42
43
|
from mindspore._checkparam import args_type_check, check_hook_fn
|
|
43
|
-
from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
|
|
44
|
+
from mindspore.common.dynamic_shape._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
|
|
44
45
|
from mindspore import log as logger
|
|
45
|
-
from mindspore.common.hook_handle import HookHandle
|
|
46
|
+
from mindspore.common.hook_handle import HookHandle, _update_hook_version
|
|
46
47
|
from mindspore import context
|
|
47
48
|
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
|
|
48
49
|
from mindspore import _checkparam as Validator
|
|
@@ -92,9 +93,8 @@ def register_cell_buffer_registration_hook(hook: Callable[..., None],):
|
|
|
92
93
|
A handle that can be used to remove the added hook by calling
|
|
93
94
|
`handle.remove()`.
|
|
94
95
|
"""
|
|
95
|
-
|
|
96
|
-
handle =
|
|
97
|
-
_global_buffer_registration_hooks[handle.id] = hook
|
|
96
|
+
handle = HookHandle(_global_buffer_registration_hooks)
|
|
97
|
+
_global_buffer_registration_hooks[handle.handle_id] = hook
|
|
98
98
|
return handle
|
|
99
99
|
|
|
100
100
|
|
|
@@ -155,7 +155,8 @@ class Cell(Cell_):
|
|
|
155
155
|
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
|
|
156
156
|
'_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
|
|
157
157
|
'_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
|
|
158
|
-
'_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix',
|
|
158
|
+
'_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix',
|
|
159
|
+
'requires_grad', 'cell_type', '_in_strategy', '_out_strategy']
|
|
159
160
|
total_instance_count = 0
|
|
160
161
|
_buffers: Dict[str, Optional[Tensor]]
|
|
161
162
|
global_cells = weakref.WeakKeyDictionary()
|
|
@@ -191,6 +192,7 @@ class Cell(Cell_):
|
|
|
191
192
|
super().__setattr__("_auto_prefix", auto_prefix)
|
|
192
193
|
super().__setattr__("_scope", None)
|
|
193
194
|
super().__setattr__("_phase", 'train')
|
|
195
|
+
super().__setattr__("_compile_phase", None)
|
|
194
196
|
super().__setattr__("_parameter_layout_dict", None)
|
|
195
197
|
super().__setattr__("_parallel_parameter_name_list", None)
|
|
196
198
|
super().__setattr__("_parallel_parameter_merge_net_dict", None)
|
|
@@ -206,6 +208,7 @@ class Cell(Cell_):
|
|
|
206
208
|
super().__setattr__("mixed_precision_type", None)
|
|
207
209
|
super().__setattr__("_lazy_construct_sig", None)
|
|
208
210
|
super().__setattr__("_jit_graph_name", '')
|
|
211
|
+
super().__setattr__("_compiled", False)
|
|
209
212
|
init_pipeline()
|
|
210
213
|
|
|
211
214
|
# call gc to release GE session resources used by non-used cell objects
|
|
@@ -239,6 +242,8 @@ class Cell(Cell_):
|
|
|
239
242
|
super().__setattr__("_amp_level", "")
|
|
240
243
|
super().__setattr__("_init_flag", False)
|
|
241
244
|
super().__setattr__("_shard_fn", None)
|
|
245
|
+
super().__setattr__("_in_strategy", None)
|
|
246
|
+
super().__setattr__("_out_strategy", None)
|
|
242
247
|
super().__setattr__("has_bprop", False)
|
|
243
248
|
if hasattr(self, "bprop"):
|
|
244
249
|
super().__setattr__("has_bprop", True)
|
|
@@ -426,6 +431,13 @@ class Cell(Cell_):
|
|
|
426
431
|
"""
|
|
427
432
|
return self._bprop_debug
|
|
428
433
|
|
|
434
|
+
@property
|
|
435
|
+
def compiled(self):
|
|
436
|
+
"""
|
|
437
|
+
Get whether `Cell` is compiled in graph mode.
|
|
438
|
+
"""
|
|
439
|
+
return self._compiled
|
|
440
|
+
|
|
429
441
|
@bprop_debug.setter
|
|
430
442
|
def bprop_debug(self, value):
|
|
431
443
|
"""
|
|
@@ -482,6 +494,19 @@ class Cell(Cell_):
|
|
|
482
494
|
raise TypeError(f"For 'Cell', the property 'phase' must be string type, but got type {type(value)}.")
|
|
483
495
|
self._phase = value
|
|
484
496
|
|
|
497
|
+
@property
|
|
498
|
+
def compile_phase(self):
|
|
499
|
+
return self._compile_phase
|
|
500
|
+
|
|
501
|
+
@compile_phase.setter
|
|
502
|
+
def compile_phase(self, value):
|
|
503
|
+
if not isinstance(value, str):
|
|
504
|
+
raise TypeError(f"For 'Cell', 'compile_phase' must be string type, but got type {type(value)}.")
|
|
505
|
+
self._compile_phase = value
|
|
506
|
+
for cell in self._cells.values():
|
|
507
|
+
if cell is not None:
|
|
508
|
+
cell.compile_phase = value
|
|
509
|
+
|
|
485
510
|
@property
|
|
486
511
|
def parameter_layout_dict(self):
|
|
487
512
|
"""
|
|
@@ -546,10 +571,23 @@ class Cell(Cell_):
|
|
|
546
571
|
|
|
547
572
|
@property
|
|
548
573
|
def pipeline_segment(self):
|
|
574
|
+
"""
|
|
575
|
+
`pipeline_segment` represents the pipeline segment of current Cell.
|
|
576
|
+
"""
|
|
549
577
|
return self._pipeline_segment
|
|
550
578
|
|
|
551
579
|
@pipeline_segment.setter
|
|
552
580
|
def pipeline_segment(self, value):
|
|
581
|
+
"""
|
|
582
|
+
Set the `pipeline_segment` of a Cell. Only effective in zero_bubble_v scheduler.
|
|
583
|
+
|
|
584
|
+
Args:
|
|
585
|
+
value (int): The pipeline segment of a parameter.
|
|
586
|
+
|
|
587
|
+
Raises:
|
|
588
|
+
TypeError: If `value` is not int type or is a bool type.
|
|
589
|
+
ValueError: If `value` is not a positive integer.
|
|
590
|
+
"""
|
|
553
591
|
if not isinstance(value, int) or isinstance(value, bool):
|
|
554
592
|
raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
|
|
555
593
|
"must be int type, but got type : {}".format(type(value)))
|
|
@@ -1027,12 +1065,13 @@ class Cell(Cell_):
|
|
|
1027
1065
|
if self._forward_pre_hook:
|
|
1028
1066
|
args, kwargs = self._run_forward_pre_hook(args, kwargs)
|
|
1029
1067
|
|
|
1068
|
+
if self._backward_hook:
|
|
1069
|
+
args = self._cell_backward_hook(args)
|
|
1070
|
+
|
|
1030
1071
|
if self._shard_fn is not None:
|
|
1031
1072
|
output = self._shard_fn(*args, **kwargs)
|
|
1032
1073
|
elif _pynative_executor.requires_grad():
|
|
1033
|
-
if self.
|
|
1034
|
-
output = self._backward_hook_construct(*args, **kwargs)
|
|
1035
|
-
elif self._recompute_cell is not None:
|
|
1074
|
+
if self._recompute_cell is not None:
|
|
1036
1075
|
output = self._recompute_cell(*args, **kwargs)
|
|
1037
1076
|
elif self.has_bprop:
|
|
1038
1077
|
output = self._call_custom_bprop(*args, **kwargs)
|
|
@@ -1044,8 +1083,11 @@ class Cell(Cell_):
|
|
|
1044
1083
|
if self._forward_hook:
|
|
1045
1084
|
output = self._run_forward_hook(args, kwargs, output)
|
|
1046
1085
|
|
|
1047
|
-
if self.
|
|
1048
|
-
output = self.
|
|
1086
|
+
if self._backward_hook:
|
|
1087
|
+
output = self._cell_backward_hook(output)
|
|
1088
|
+
|
|
1089
|
+
if self._backward_pre_hook:
|
|
1090
|
+
output = self._cell_backward_pre_hook(output)
|
|
1049
1091
|
|
|
1050
1092
|
return output
|
|
1051
1093
|
|
|
@@ -1080,23 +1122,6 @@ class Cell(Cell_):
|
|
|
1080
1122
|
f"{default_args} default argument, total {positional_args + default_args}, "
|
|
1081
1123
|
f"but got {len(args)}.")
|
|
1082
1124
|
|
|
1083
|
-
# pylint: disable=E0203
|
|
1084
|
-
def _hook_fn_registered(self):
|
|
1085
|
-
'''Hook function in graph mode'''
|
|
1086
|
-
# Check super().__init__() in graph mode.
|
|
1087
|
-
try:
|
|
1088
|
-
if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
|
|
1089
|
-
return True
|
|
1090
|
-
except AttributeError as e:
|
|
1091
|
-
raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
|
|
1092
|
-
f"Please use 'super().__init__()'.") from e
|
|
1093
|
-
if not self._is_recursion_hook:
|
|
1094
|
-
self._is_recursion_hook = True
|
|
1095
|
-
for cell in self.cells():
|
|
1096
|
-
if cell._hook_fn_registered():
|
|
1097
|
-
return True
|
|
1098
|
-
return False
|
|
1099
|
-
|
|
1100
1125
|
def _get_prims_recursively(self):
|
|
1101
1126
|
all_prims = list()
|
|
1102
1127
|
for _, value in self._primitives.items():
|
|
@@ -1122,9 +1147,6 @@ class Cell(Cell_):
|
|
|
1122
1147
|
>>> net = nn.Dense(3, 4)
|
|
1123
1148
|
>>> net.set_data_parallel()
|
|
1124
1149
|
"""
|
|
1125
|
-
if context._get_mode() == context.PYNATIVE_MODE:
|
|
1126
|
-
raise ValueError("set_data_parallel: does not support PyNative mode.")
|
|
1127
|
-
|
|
1128
1150
|
all_prims = self._get_prims_recursively()
|
|
1129
1151
|
for prim in all_prims:
|
|
1130
1152
|
prim.add_prim_attr("strategy_gen_mode", "data_parallel")
|
|
@@ -1203,8 +1225,6 @@ class Cell(Cell_):
|
|
|
1203
1225
|
... out = self.blocks[i](out)
|
|
1204
1226
|
... return out
|
|
1205
1227
|
"""
|
|
1206
|
-
if context._get_mode() == context.PYNATIVE_MODE:
|
|
1207
|
-
raise ValueError("The Cell offload does not support PyNative mode now.")
|
|
1208
1228
|
if isinstance(backward_prefetch, str):
|
|
1209
1229
|
Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
|
|
1210
1230
|
else:
|
|
@@ -1212,11 +1232,10 @@ class Cell(Cell_):
|
|
|
1212
1232
|
for prim in self._get_prims_recursively():
|
|
1213
1233
|
prim._offload(backward_prefetch=backward_prefetch)
|
|
1214
1234
|
|
|
1215
|
-
def shard(self, in_strategy, out_strategy=None, parameter_plan=None
|
|
1235
|
+
def shard(self, in_strategy, out_strategy=None, parameter_plan=None):
|
|
1216
1236
|
"""
|
|
1217
1237
|
Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
|
|
1218
|
-
generated by sharding propagation. In
|
|
1219
|
-
execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
|
|
1238
|
+
generated by sharding propagation. In Graph mode, use this method to specify distribution strategy for a Cell,
|
|
1220
1239
|
strategy for others will be set by sharding propagation.
|
|
1221
1240
|
in_strategy and out_strategy define the input and output layout respectively.
|
|
1222
1241
|
in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
|
|
@@ -1228,11 +1247,14 @@ class Cell(Cell_):
|
|
|
1228
1247
|
In other parallel modes, strategies set here will be ignored.
|
|
1229
1248
|
- If the input contain Parameter, its strategy should be set in `in_strategy`.
|
|
1230
1249
|
|
|
1250
|
+
.. warning::
|
|
1251
|
+
The method is currently not supported in PyNative mode.
|
|
1252
|
+
|
|
1231
1253
|
Args:
|
|
1232
1254
|
in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
|
|
1233
1255
|
defines the layout of the corresponding input.
|
|
1234
1256
|
out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
|
|
1235
|
-
|
|
1257
|
+
Default: ``None`` .
|
|
1236
1258
|
parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
|
|
1237
1259
|
defines the layout of the parameter like "param_name: layout".
|
|
1238
1260
|
The key is a parameter name of type 'str'.
|
|
@@ -1240,14 +1262,6 @@ class Cell(Cell_):
|
|
|
1240
1262
|
If the parameter name is incorrect or the corresponding parameter
|
|
1241
1263
|
has been set, the parameter setting will be ignored.
|
|
1242
1264
|
Default: ``None`` .
|
|
1243
|
-
device (str): Select a certain device target. It is not in use right now.
|
|
1244
|
-
Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
|
|
1245
|
-
level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
|
|
1246
|
-
over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
|
|
1247
|
-
use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .
|
|
1248
|
-
|
|
1249
|
-
Returns:
|
|
1250
|
-
Function, return the cell construct function that will be executed under auto parallel process.
|
|
1251
1265
|
|
|
1252
1266
|
Examples:
|
|
1253
1267
|
>>> import mindspore.nn as nn
|
|
@@ -1265,19 +1279,34 @@ class Cell(Cell_):
|
|
|
1265
1279
|
... def __init__(self):
|
|
1266
1280
|
... self.block1 = Block()
|
|
1267
1281
|
... self.block2 = Block()
|
|
1268
|
-
... self.
|
|
1269
|
-
... parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
|
|
1282
|
+
... self.block2.shard(in_strategy=((2, 1),), parameter_plan={'self.block2.dense1.weight': (4, 1)})
|
|
1270
1283
|
... def construct(self, x):
|
|
1271
1284
|
... x = self.block1(x)
|
|
1272
|
-
... x = self.
|
|
1285
|
+
... x = self.block2(x)
|
|
1273
1286
|
... return x
|
|
1274
1287
|
"""
|
|
1275
1288
|
if ms.communication.management.get_group_size() == 1:
|
|
1276
|
-
return
|
|
1289
|
+
return
|
|
1290
|
+
|
|
1277
1291
|
shard_fn = Shard()
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1292
|
+
self._shard_fn = shard_fn(self, in_strategy, out_strategy, parameter_plan)
|
|
1293
|
+
|
|
1294
|
+
if self._in_strategy is not None: # pylint: disable=E0203
|
|
1295
|
+
msg = (
|
|
1296
|
+
"For '%s', 'Shard' has been configured more than once. "
|
|
1297
|
+
"The existing in_strategy is %s and the existing out_strategy is %s. "
|
|
1298
|
+
"The new in_strategy %s and out_strategy %s may not take effect. "
|
|
1299
|
+
"It is recommended to configure 'Shard' only once."
|
|
1300
|
+
) % (
|
|
1301
|
+
self._cell_tag,
|
|
1302
|
+
self._in_strategy, # pylint: disable=E0203
|
|
1303
|
+
self._out_strategy, # pylint: disable=E0203
|
|
1304
|
+
shard_fn.in_strategy,
|
|
1305
|
+
shard_fn.out_strategy,
|
|
1306
|
+
)
|
|
1307
|
+
logger.warning(msg)
|
|
1308
|
+
self._in_strategy = shard_fn.in_strategy
|
|
1309
|
+
self._out_strategy = shard_fn.out_strategy
|
|
1281
1310
|
|
|
1282
1311
|
def _init_check(self):
|
|
1283
1312
|
for param in self.get_parameters(expand=False):
|
|
@@ -1286,9 +1315,13 @@ class Cell(Cell_):
|
|
|
1286
1315
|
self._init_flag = True
|
|
1287
1316
|
|
|
1288
1317
|
def _self_check(self):
|
|
1289
|
-
|
|
1290
|
-
self.
|
|
1291
|
-
|
|
1318
|
+
try:
|
|
1319
|
+
if not self._is_check_and_refresh: # pylint: disable=E0203
|
|
1320
|
+
self.check_names_and_refresh_name()
|
|
1321
|
+
self._is_check_and_refresh = True
|
|
1322
|
+
except AttributeError as e:
|
|
1323
|
+
raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
|
|
1324
|
+
f"Please use 'super().__init__()'.") from e
|
|
1292
1325
|
|
|
1293
1326
|
def _predict(self, *args, **kwargs):
|
|
1294
1327
|
'''Graph executor for predict'''
|
|
@@ -1309,6 +1342,7 @@ class Cell(Cell_):
|
|
|
1309
1342
|
def __call__(self, *args, **kwargs):
|
|
1310
1343
|
# Run in Graph mode.
|
|
1311
1344
|
if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
|
|
1345
|
+
self._compiled = True
|
|
1312
1346
|
if kwargs:
|
|
1313
1347
|
bound_arguments = self._construct_sig.bind(*args, **kwargs)
|
|
1314
1348
|
bound_arguments.apply_defaults()
|
|
@@ -1319,11 +1353,8 @@ class Cell(Cell_):
|
|
|
1319
1353
|
if predict_compiled:
|
|
1320
1354
|
return res
|
|
1321
1355
|
self._check_construct_args(*args)
|
|
1322
|
-
|
|
1323
|
-
if self._hook_fn_registered():
|
|
1324
|
-
logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
|
|
1325
|
-
f"function, please use context.set_context to set pynative mode.")
|
|
1326
1356
|
self._self_check()
|
|
1357
|
+
self.__compile_cell_hook__ = True
|
|
1327
1358
|
out = self.compile_and_run(*args, **kwargs)
|
|
1328
1359
|
return out
|
|
1329
1360
|
|
|
@@ -1421,16 +1452,7 @@ class Cell(Cell_):
|
|
|
1421
1452
|
exist_names.add(item.name)
|
|
1422
1453
|
self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
|
|
1423
1454
|
|
|
1424
|
-
|
|
1425
|
-
if name in self.__dict__:
|
|
1426
|
-
del self.__dict__[name]
|
|
1427
|
-
params = self.__dict__.get('_params')
|
|
1428
|
-
if name in params:
|
|
1429
|
-
del params[name]
|
|
1430
|
-
params_list = self.__dict__.get('_params_list')
|
|
1431
|
-
params_list[name] = value
|
|
1432
|
-
else:
|
|
1433
|
-
object.__setattr__(self, name, value)
|
|
1455
|
+
object.__setattr__(self, name, value)
|
|
1434
1456
|
|
|
1435
1457
|
def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
|
|
1436
1458
|
"""Set attr for parameter in list or tuple."""
|
|
@@ -1609,8 +1631,6 @@ class Cell(Cell_):
|
|
|
1609
1631
|
_pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
|
|
1610
1632
|
else:
|
|
1611
1633
|
self._check_construct_args(*inputs)
|
|
1612
|
-
# TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
|
|
1613
|
-
# which means that incremental mode is lacking dynamic input.
|
|
1614
1634
|
else:
|
|
1615
1635
|
self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)
|
|
1616
1636
|
|
|
@@ -1699,6 +1719,7 @@ class Cell(Cell_):
|
|
|
1699
1719
|
_init_auto_parallel_context(self)
|
|
1700
1720
|
compile_args = self._get_compile_args(args)
|
|
1701
1721
|
self._has_mutable_args_list = _get_mutable_flags(compile_args)
|
|
1722
|
+
_cell_graph_executor.set_real_args(args, kwargs)
|
|
1702
1723
|
_cell_graph_executor.compile(self, *compile_args, phase=self.phase,
|
|
1703
1724
|
jit_config_dict=self._jit_config_dict, **kwargs)
|
|
1704
1725
|
_clear_auto_parallel_context(self)
|
|
@@ -2581,23 +2602,7 @@ class Cell(Cell_):
|
|
|
2581
2602
|
else:
|
|
2582
2603
|
self._jit_config_dict = jit_config.jit_config_dict
|
|
2583
2604
|
|
|
2584
|
-
|
|
2585
|
-
"""
|
|
2586
|
-
Reset data for weight parameters so that they are using contiguous memory chunks grouped by data type.
|
|
2587
|
-
|
|
2588
|
-
Note:
|
|
2589
|
-
By default, parameters with same data type will using a single contiguous memory chunk. but for
|
|
2590
|
-
some models with huge number of parameters, splitting a large memory chunk into several smaller
|
|
2591
|
-
memory chunks has the potential for performance gains, if this is the case, we can use 'fusion_size'
|
|
2592
|
-
to limit the maximum memory chunk size.
|
|
2593
|
-
|
|
2594
|
-
Args:
|
|
2595
|
-
fusion_size (int): Maximum memory chunk size in bytes, ``0`` for unlimited. Default: ``0`` .
|
|
2596
|
-
"""
|
|
2597
|
-
if fusion_size < 0:
|
|
2598
|
-
raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
|
|
2599
|
-
Tensor._flatten_tensors(self.trainable_params(), fusion_size) # pylint: disable=W0212
|
|
2600
|
-
|
|
2605
|
+
@jit_forbidden_register
|
|
2601
2606
|
def register_forward_pre_hook(self, hook_fn, with_kwargs=False):
|
|
2602
2607
|
"""
|
|
2603
2608
|
Register forward pre hook function for Cell object.
|
|
@@ -2617,7 +2622,6 @@ class Cell(Cell_):
|
|
|
2617
2622
|
`with_kwargs` is ``True`` .
|
|
2618
2623
|
|
|
2619
2624
|
Note:
|
|
2620
|
-
- The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
|
|
2621
2625
|
- The `hook_fn` can modify the forward inputs by returning new inputs. If `with_kwargs` is ``False`` , a
|
|
2622
2626
|
single value (which will be wrapped into a tuple unless already a tuple) or a tuple of args should be
|
|
2623
2627
|
returned. If `with_kwargs` is ``True`` , both `args` and `kwargs` should be returned.
|
|
@@ -2668,15 +2672,15 @@ class Cell(Cell_):
|
|
|
2668
2672
|
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
|
|
2669
2673
|
value= [ 2.00000000e+00]))
|
|
2670
2674
|
"""
|
|
2671
|
-
if context._get_mode() == context.GRAPH_MODE:
|
|
2672
|
-
return HookHandle()
|
|
2673
2675
|
check_hook_fn(hook_fn)
|
|
2674
2676
|
handle = HookHandle(self._forward_pre_hook, extra_dict=self._forward_pre_hook_with_kwargs)
|
|
2675
2677
|
self._forward_pre_hook[handle.handle_id] = hook_fn
|
|
2676
2678
|
if with_kwargs:
|
|
2677
2679
|
self._forward_pre_hook_with_kwargs[handle.handle_id] = True
|
|
2680
|
+
_update_hook_version()
|
|
2678
2681
|
return handle
|
|
2679
2682
|
|
|
2683
|
+
@jit_forbidden_register
|
|
2680
2684
|
def _run_forward_pre_hook(self, args, kwargs):
|
|
2681
2685
|
"""
|
|
2682
2686
|
Running forward pre hook function registered on Cell object.
|
|
@@ -2700,6 +2704,35 @@ class Cell(Cell_):
|
|
|
2700
2704
|
args = ret
|
|
2701
2705
|
return args, kwargs
|
|
2702
2706
|
|
|
2707
|
+
def _jit_forward_pre_hook(self, inputs):
|
|
2708
|
+
"""
|
|
2709
|
+
Compile forward pre hook function registered on Cell object.
|
|
2710
|
+
|
|
2711
|
+
Args:
|
|
2712
|
+
inputs: The input objects of cell object.
|
|
2713
|
+
|
|
2714
|
+
Returns:
|
|
2715
|
+
- **outputs** - New input objects or none.
|
|
2716
|
+
|
|
2717
|
+
Supported Platforms:
|
|
2718
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2719
|
+
"""
|
|
2720
|
+
forward_pre_hook_inputs = inputs
|
|
2721
|
+
for fn in self._forward_pre_hook.values():
|
|
2722
|
+
ret = fn(self, forward_pre_hook_inputs)
|
|
2723
|
+
if ret is not None:
|
|
2724
|
+
if not isinstance(ret, tuple):
|
|
2725
|
+
forward_pre_hook_inputs = (ret,)
|
|
2726
|
+
else:
|
|
2727
|
+
forward_pre_hook_inputs = ret
|
|
2728
|
+
|
|
2729
|
+
if len(forward_pre_hook_inputs) != len(inputs):
|
|
2730
|
+
raise TypeError(
|
|
2731
|
+
"The forward pre hook return value size is {} not equal to input size {}".format(
|
|
2732
|
+
len(forward_pre_hook_inputs), len(inputs)))
|
|
2733
|
+
return forward_pre_hook_inputs
|
|
2734
|
+
|
|
2735
|
+
@jit_forbidden_register
|
|
2703
2736
|
def register_forward_hook(self, hook_fn, with_kwargs=False):
|
|
2704
2737
|
"""
|
|
2705
2738
|
Register forward hook function for Cell object.
|
|
@@ -2720,7 +2753,6 @@ class Cell(Cell_):
|
|
|
2720
2753
|
- `output`: Output generated by the `construct` function.
|
|
2721
2754
|
|
|
2722
2755
|
Note:
|
|
2723
|
-
- The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
|
|
2724
2756
|
- The `hook_fn` can modify the forward outputs by returning new outputs.
|
|
2725
2757
|
- In order to prevent running failed when switching to graph mode, it is not recommended to call it in the
|
|
2726
2758
|
`construct` function of Cell object.
|
|
@@ -2773,15 +2805,44 @@ class Cell(Cell_):
|
|
|
2773
2805
|
"""
|
|
2774
2806
|
if self.has_bprop:
|
|
2775
2807
|
return HookHandle()
|
|
2776
|
-
if context._get_mode() == context.GRAPH_MODE:
|
|
2777
|
-
return HookHandle()
|
|
2778
2808
|
check_hook_fn(hook_fn)
|
|
2779
2809
|
handle = HookHandle(self._forward_hook, extra_dict=self._forward_hook_with_kwargs)
|
|
2780
2810
|
self._forward_hook[handle.handle_id] = hook_fn
|
|
2781
2811
|
if with_kwargs:
|
|
2782
2812
|
self._forward_hook_with_kwargs[handle.handle_id] = True
|
|
2813
|
+
_update_hook_version()
|
|
2783
2814
|
return handle
|
|
2784
2815
|
|
|
2816
|
+
def _jit_forward_hook(self, inputs, output):
|
|
2817
|
+
"""
|
|
2818
|
+
Compile forward hook function registered on Cell object.
|
|
2819
|
+
|
|
2820
|
+
Args:
|
|
2821
|
+
inputs: The input objects of Cell object.
|
|
2822
|
+
output: The output object of Cell object.
|
|
2823
|
+
|
|
2824
|
+
Returns:
|
|
2825
|
+
- **output** - New output object or none.
|
|
2826
|
+
|
|
2827
|
+
Supported Platforms:
|
|
2828
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
2829
|
+
"""
|
|
2830
|
+
forward_hook_output = output
|
|
2831
|
+
for fn in self._forward_hook.values():
|
|
2832
|
+
ret = fn(self, inputs, forward_hook_output)
|
|
2833
|
+
if ret is not None:
|
|
2834
|
+
forward_hook_output = ret
|
|
2835
|
+
|
|
2836
|
+
if isinstance(output, tuple):
|
|
2837
|
+
if not isinstance(forward_hook_output, tuple):
|
|
2838
|
+
forward_hook_output = (forward_hook_output,)
|
|
2839
|
+
if len(forward_hook_output) != len(output):
|
|
2840
|
+
raise TypeError(
|
|
2841
|
+
"The forward hook return value size is {} not equal to output size {}".format(
|
|
2842
|
+
len(forward_hook_output), len(output)))
|
|
2843
|
+
return forward_hook_output
|
|
2844
|
+
|
|
2845
|
+
@jit_forbidden_register
|
|
2785
2846
|
def _run_forward_hook(self, args, kwargs, output):
|
|
2786
2847
|
"""
|
|
2787
2848
|
Running forward hook function registered on Cell object.
|
|
@@ -2795,12 +2856,12 @@ class Cell(Cell_):
|
|
|
2795
2856
|
output = ret
|
|
2796
2857
|
return output
|
|
2797
2858
|
|
|
2859
|
+
@jit_forbidden_register
|
|
2798
2860
|
def register_backward_pre_hook(self, hook_fn):
|
|
2799
2861
|
"""
|
|
2800
2862
|
Register the backward pre hook function.
|
|
2801
2863
|
|
|
2802
2864
|
Note:
|
|
2803
|
-
- The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
|
|
2804
2865
|
- The 'hook_fn' must be defined as the following code.
|
|
2805
2866
|
`cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
|
|
2806
2867
|
- The 'hook_fn' should have the following signature:
|
|
@@ -2849,44 +2910,17 @@ class Cell(Cell_):
|
|
|
2849
2910
|
>>> print(output)
|
|
2850
2911
|
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
|
|
2851
2912
|
"""
|
|
2852
|
-
if context._get_mode() == context.GRAPH_MODE:
|
|
2853
|
-
return HookHandle()
|
|
2854
2913
|
check_hook_fn(hook_fn)
|
|
2855
|
-
handle = HookHandle(self._backward_pre_hook)
|
|
2914
|
+
handle = HookHandle(self._backward_pre_hook, extra_dict=None)
|
|
2856
2915
|
self._backward_pre_hook[handle.handle_id] = hook_fn
|
|
2857
|
-
if self._cell_backward_pre_hook is None:
|
|
2916
|
+
if self._cell_backward_pre_hook is None: # pylint: disable=E0203
|
|
2858
2917
|
# Generate a CellBackwardHook prim, and add function for it
|
|
2859
2918
|
self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
|
|
2860
2919
|
self, self._backward_pre_hook)
|
|
2861
2920
|
self._cell_backward_pre_hook.register_backward_pre_hook()
|
|
2921
|
+
_update_hook_version()
|
|
2862
2922
|
return handle
|
|
2863
2923
|
|
|
2864
|
-
def _run_backward_pre_hook(self, outputs):
|
|
2865
|
-
"""
|
|
2866
|
-
Running backward pre hook function registered on Cell object.
|
|
2867
|
-
|
|
2868
|
-
Args:
|
|
2869
|
-
outputs: The output objects of cell object.
|
|
2870
|
-
|
|
2871
|
-
Returns:
|
|
2872
|
-
- **outputs** - New backward gradient or None.
|
|
2873
|
-
|
|
2874
|
-
Supported Platforms:
|
|
2875
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
2876
|
-
"""
|
|
2877
|
-
if isinstance(outputs, tuple):
|
|
2878
|
-
ret = self._cell_backward_pre_hook(*outputs)
|
|
2879
|
-
else:
|
|
2880
|
-
ret = self._cell_backward_pre_hook(outputs)
|
|
2881
|
-
if isinstance(outputs, tuple):
|
|
2882
|
-
if len(outputs) == 1:
|
|
2883
|
-
ret = (ret,)
|
|
2884
|
-
if len(ret) != len(outputs):
|
|
2885
|
-
raise TypeError(
|
|
2886
|
-
"The backward pre hook return value size is {} not equal to output size {}".format(
|
|
2887
|
-
len(ret), len(outputs)))
|
|
2888
|
-
return ret
|
|
2889
|
-
|
|
2890
2924
|
def get_extra_state(self) -> Any:
|
|
2891
2925
|
"""Return any extra state to include in the cell's state_dict.
|
|
2892
2926
|
|
|
@@ -2939,9 +2973,8 @@ class Cell(Cell_):
|
|
|
2939
2973
|
A handle that can be used to remove the added hook by calling
|
|
2940
2974
|
`handle.remove()`.
|
|
2941
2975
|
"""
|
|
2942
|
-
|
|
2943
|
-
handle =
|
|
2944
|
-
self._state_dict_hooks[handle.id] = hook
|
|
2976
|
+
handle = HookHandle(self._state_dict_hooks)
|
|
2977
|
+
self._state_dict_hooks[handle.handle_id] = hook
|
|
2945
2978
|
return handle
|
|
2946
2979
|
|
|
2947
2980
|
@jit_forbidden_register
|
|
@@ -2987,9 +3020,8 @@ class Cell(Cell_):
|
|
|
2987
3020
|
>>> print("extra_param" in net_state_dict)
|
|
2988
3021
|
True
|
|
2989
3022
|
"""
|
|
2990
|
-
|
|
2991
|
-
handle =
|
|
2992
|
-
self._state_dict_pre_hooks[handle.id] = hook
|
|
3023
|
+
handle = HookHandle(self._state_dict_pre_hooks)
|
|
3024
|
+
self._state_dict_pre_hooks[handle.handle_id] = hook
|
|
2993
3025
|
return handle
|
|
2994
3026
|
|
|
2995
3027
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
@@ -3135,9 +3167,8 @@ class Cell(Cell_):
|
|
|
3135
3167
|
A handle that can be used to remove the added hook by calling
|
|
3136
3168
|
`handle.remove()`.
|
|
3137
3169
|
"""
|
|
3138
|
-
|
|
3139
|
-
handle =
|
|
3140
|
-
self._load_state_dict_pre_hooks[handle.id] = hook
|
|
3170
|
+
handle = HookHandle(self._load_state_dict_pre_hooks)
|
|
3171
|
+
self._load_state_dict_pre_hooks[handle.handle_id] = hook
|
|
3141
3172
|
return handle
|
|
3142
3173
|
|
|
3143
3174
|
@jit_forbidden_register
|
|
@@ -3169,9 +3200,8 @@ class Cell(Cell_):
|
|
|
3169
3200
|
A handle that can be used to remove the added hook by calling
|
|
3170
3201
|
`handle.remove()`.
|
|
3171
3202
|
"""
|
|
3172
|
-
|
|
3173
|
-
handle =
|
|
3174
|
-
self._load_state_dict_post_hooks[handle.id] = hook
|
|
3203
|
+
handle = HookHandle(self._load_state_dict_post_hooks)
|
|
3204
|
+
self._load_state_dict_post_hooks[handle.handle_id] = hook
|
|
3175
3205
|
return handle
|
|
3176
3206
|
|
|
3177
3207
|
def _load_from_state_dict(
|
|
@@ -3407,12 +3437,12 @@ class Cell(Cell_):
|
|
|
3407
3437
|
)
|
|
3408
3438
|
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
|
3409
3439
|
|
|
3440
|
+
@jit_forbidden_register
|
|
3410
3441
|
def register_backward_hook(self, hook_fn):
|
|
3411
3442
|
"""
|
|
3412
3443
|
Register the backward hook function.
|
|
3413
3444
|
|
|
3414
3445
|
Note:
|
|
3415
|
-
- The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
|
|
3416
3446
|
- The 'hook_fn' must be defined as the following code.
|
|
3417
3447
|
`cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
|
|
3418
3448
|
the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
|
|
@@ -3464,83 +3494,17 @@ class Cell(Cell_):
|
|
|
3464
3494
|
>>> print(output)
|
|
3465
3495
|
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
|
|
3466
3496
|
"""
|
|
3467
|
-
if context._get_mode() == context.GRAPH_MODE:
|
|
3468
|
-
return HookHandle()
|
|
3469
3497
|
check_hook_fn(hook_fn)
|
|
3470
|
-
handle = HookHandle(self._backward_hook)
|
|
3498
|
+
handle = HookHandle(self._backward_hook, extra_dict=None)
|
|
3471
3499
|
self._backward_hook[handle.handle_id] = hook_fn
|
|
3472
|
-
if self._cell_backward_hook is None:
|
|
3500
|
+
if self._cell_backward_hook is None: # pylint: disable=E0203
|
|
3473
3501
|
# Generate a CellBackwardHook prim, and add function for it
|
|
3474
3502
|
self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
|
|
3475
3503
|
self, self._backward_hook)
|
|
3476
3504
|
self._cell_backward_hook.register_backward_hook()
|
|
3505
|
+
_update_hook_version()
|
|
3477
3506
|
return handle
|
|
3478
3507
|
|
|
3479
|
-
def _backward_hook_construct(self, *inputs, **kwargs):
|
|
3480
|
-
"""
|
|
3481
|
-
Backward hook construct method to replace original construct method.
|
|
3482
|
-
|
|
3483
|
-
Args:
|
|
3484
|
-
inputs: The input objects of Cell object.
|
|
3485
|
-
kwargs (dict): Dictionary of variable keyword parameters.
|
|
3486
|
-
|
|
3487
|
-
Returns:
|
|
3488
|
-
- **outputs** - The output objects of Cell object.
|
|
3489
|
-
|
|
3490
|
-
Supported Platforms:
|
|
3491
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
3492
|
-
"""
|
|
3493
|
-
# cell_backward_hook has CellBackwardHook op, so keep input args as they are.
|
|
3494
|
-
outputs = self._cell_backward_hook(*inputs)
|
|
3495
|
-
# If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
|
|
3496
|
-
# a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
|
|
3497
|
-
# Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
|
|
3498
|
-
is_need_unwrap = False
|
|
3499
|
-
if isinstance(outputs, tuple) and len(inputs) != 1:
|
|
3500
|
-
is_need_unwrap = True
|
|
3501
|
-
|
|
3502
|
-
if self._recompute_cell is not None:
|
|
3503
|
-
if is_need_unwrap:
|
|
3504
|
-
outputs = self._recompute_cell(*outputs, **kwargs)
|
|
3505
|
-
else:
|
|
3506
|
-
outputs = self._recompute_cell(outputs, **kwargs)
|
|
3507
|
-
elif self.has_bprop:
|
|
3508
|
-
if is_need_unwrap:
|
|
3509
|
-
outputs = self._call_custom_bprop(*outputs, **kwargs)
|
|
3510
|
-
else:
|
|
3511
|
-
outputs = self._call_custom_bprop(outputs, **kwargs)
|
|
3512
|
-
else:
|
|
3513
|
-
if is_need_unwrap:
|
|
3514
|
-
outputs = self.construct(*outputs, **kwargs)
|
|
3515
|
-
else:
|
|
3516
|
-
outputs = self.construct(outputs, **kwargs)
|
|
3517
|
-
if isinstance(outputs, tuple):
|
|
3518
|
-
new_outputs = self._cell_backward_hook(*outputs)
|
|
3519
|
-
else:
|
|
3520
|
-
new_outputs = self._cell_backward_hook(outputs)
|
|
3521
|
-
# if outputs is (X,) and new_outpus is X
|
|
3522
|
-
if isinstance(outputs, tuple) and len(outputs) == 1:
|
|
3523
|
-
new_outputs = (new_outputs,)
|
|
3524
|
-
return new_outputs
|
|
3525
|
-
|
|
3526
|
-
def set_param_ps(self, recurse=True, init_in_server=False):
|
|
3527
|
-
"""
|
|
3528
|
-
Set whether the trainable parameters are updated by parameter server and whether the
|
|
3529
|
-
trainable parameters are initialized on server.
|
|
3530
|
-
|
|
3531
|
-
Note:
|
|
3532
|
-
It only works when a running task is in the parameter server mode.
|
|
3533
|
-
It is only supported in graph mode.
|
|
3534
|
-
|
|
3535
|
-
Args:
|
|
3536
|
-
recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` .
|
|
3537
|
-
init_in_server (bool): Whether trainable parameters updated by parameter server are
|
|
3538
|
-
initialized on server. Default: ``False`` .
|
|
3539
|
-
"""
|
|
3540
|
-
params = self.trainable_params(recurse)
|
|
3541
|
-
for param in params:
|
|
3542
|
-
param.set_param_ps(init_in_server)
|
|
3543
|
-
|
|
3544
3508
|
def set_comm_fusion(self, fusion_type, recurse=True):
|
|
3545
3509
|
"""
|
|
3546
3510
|
Set `comm_fusion` for all the parameters in this cell. Please refer to the description of
|
|
@@ -3601,7 +3565,7 @@ class Cell(Cell_):
|
|
|
3601
3565
|
"""
|
|
3602
3566
|
Validator.check_bool(mode)
|
|
3603
3567
|
Validator.check_bool(output_recompute)
|
|
3604
|
-
if not self._has_config_recompute:
|
|
3568
|
+
if not self._has_config_recompute: # pylint: disable=E0203
|
|
3605
3569
|
self._has_config_recompute = True
|
|
3606
3570
|
else:
|
|
3607
3571
|
logger.info("The recompute interface can be configured only once."
|
|
@@ -3644,8 +3608,7 @@ class Cell(Cell_):
|
|
|
3644
3608
|
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
|
|
3645
3609
|
Default: ``False`` .
|
|
3646
3610
|
"""
|
|
3647
|
-
|
|
3648
|
-
self._recompute_cell = recompute_registry.get()(self.construct)
|
|
3611
|
+
self._recompute_cell = recompute_registry.get()(self.construct)
|
|
3649
3612
|
self._recompute()
|
|
3650
3613
|
if 'mp_comm_recompute' in kwargs.keys():
|
|
3651
3614
|
self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
|
|
@@ -3662,35 +3625,6 @@ class Cell(Cell_):
|
|
|
3662
3625
|
"the key kwargs must be 'mp_comm_recompute', "
|
|
3663
3626
|
"'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)
|
|
3664
3627
|
|
|
3665
|
-
def place(self, role, rank_id):
|
|
3666
|
-
"""
|
|
3667
|
-
Set the label for all operators in this cell.
|
|
3668
|
-
This label tells MindSpore compiler on which process this cell should be launched.
|
|
3669
|
-
And each process's identical label consists of input `role` and `rank_id`.
|
|
3670
|
-
So by setting different cells with different labels, which will be launched on different processes,
|
|
3671
|
-
users can launch a distributed training or predicting job.
|
|
3672
|
-
|
|
3673
|
-
Note:
|
|
3674
|
-
- This method is effective only after
|
|
3675
|
-
`mindspore.communication.init()` is called for dynamic cluster building.
|
|
3676
|
-
|
|
3677
|
-
Args:
|
|
3678
|
-
role (str): The role of the process on which this cell will be launched.
|
|
3679
|
-
Only 'MS_WORKER' is supported for now.
|
|
3680
|
-
rank_id (int): The rank id of the process on which this cell will be launched.
|
|
3681
|
-
The rank is unique in processes with the same role.
|
|
3682
|
-
|
|
3683
|
-
Examples:
|
|
3684
|
-
>>> from mindspore import context
|
|
3685
|
-
>>> import mindspore.nn as nn
|
|
3686
|
-
>>> context.set_context(mode=context.GRAPH_MODE)
|
|
3687
|
-
>>> fc = nn.Dense(2, 3)
|
|
3688
|
-
>>> fc.place('MS_WORKER', 0)
|
|
3689
|
-
"""
|
|
3690
|
-
all_ops = self._get_prims_recursively()
|
|
3691
|
-
for op in all_ops:
|
|
3692
|
-
op.place(role, rank_id)
|
|
3693
|
-
|
|
3694
3628
|
def _get_attr_from_cell(self, network):
|
|
3695
3629
|
if not isinstance(network, Cell):
|
|
3696
3630
|
return
|
|
@@ -3705,6 +3639,64 @@ class Cell(Cell_):
|
|
|
3705
3639
|
"""
|
|
3706
3640
|
self._jit_graph_name = key
|
|
3707
3641
|
|
|
3642
|
+
def _jit_backward_pre_hook(self, grad_output):
|
|
3643
|
+
new_grad_output = grad_output
|
|
3644
|
+
if not isinstance(grad_output, tuple):
|
|
3645
|
+
new_grad_output = (grad_output,)
|
|
3646
|
+
|
|
3647
|
+
for fn in self._backward_pre_hook.values():
|
|
3648
|
+
ret = fn(self, new_grad_output)
|
|
3649
|
+
if ret is not None:
|
|
3650
|
+
if not isinstance(ret, tuple):
|
|
3651
|
+
output = (ret,)
|
|
3652
|
+
else:
|
|
3653
|
+
output = ret
|
|
3654
|
+
else:
|
|
3655
|
+
output = ops.Depend()(new_grad_output, ret)
|
|
3656
|
+
new_grad_output = output
|
|
3657
|
+
|
|
3658
|
+
if not isinstance(grad_output, tuple):
|
|
3659
|
+
if len(new_grad_output) == 1:
|
|
3660
|
+
return new_grad_output[0]
|
|
3661
|
+
raise TypeError(
|
|
3662
|
+
"The backward pre hook return value size is {} not equal to input size 1".format(
|
|
3663
|
+
len(new_grad_output)))
|
|
3664
|
+
|
|
3665
|
+
if len(new_grad_output) != len(grad_output):
|
|
3666
|
+
raise TypeError(
|
|
3667
|
+
"The backward pre hook return value size is {} not equal to input size {}".format(
|
|
3668
|
+
len(new_grad_output), len(grad_output)))
|
|
3669
|
+
|
|
3670
|
+
return new_grad_output
|
|
3671
|
+
|
|
3672
|
+
def _jit_backward_hook(self, grad_input, grad_output):
|
|
3673
|
+
backward_hook_input = grad_input
|
|
3674
|
+
backward_hook_output = grad_output
|
|
3675
|
+
if not isinstance(grad_input, tuple):
|
|
3676
|
+
backward_hook_input = (grad_input,)
|
|
3677
|
+
if not isinstance(grad_output, tuple):
|
|
3678
|
+
backward_hook_output = (grad_output,)
|
|
3679
|
+
|
|
3680
|
+
for fn in self._backward_hook.values():
|
|
3681
|
+
ret = fn(self, backward_hook_input, backward_hook_output)
|
|
3682
|
+
if ret is not None:
|
|
3683
|
+
if not isinstance(ret, tuple):
|
|
3684
|
+
output = (ret,)
|
|
3685
|
+
else:
|
|
3686
|
+
output = ret
|
|
3687
|
+
else:
|
|
3688
|
+
output = ops.Depend()(backward_hook_input, ret)
|
|
3689
|
+
|
|
3690
|
+
backward_hook_input = output
|
|
3691
|
+
|
|
3692
|
+
if not isinstance(grad_input, tuple):
|
|
3693
|
+
return backward_hook_input[0]
|
|
3694
|
+
|
|
3695
|
+
if len(backward_hook_input) != len(grad_input):
|
|
3696
|
+
raise TypeError(
|
|
3697
|
+
"The backward hook return value size is {} not equal to input size {}".format(
|
|
3698
|
+
len(backward_hook_input), len(grad_input)))
|
|
3699
|
+
return backward_hook_input
|
|
3708
3700
|
|
|
3709
3701
|
class GraphCell(Cell):
|
|
3710
3702
|
"""
|