mindspore-2.6.0-cp310-cp310-win_amd64.whl → mindspore-2.7.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +2 -2
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +42 -11
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
- mindspore/_extends/parse/parser.py +64 -83
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +47 -14
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
- mindspore/amp.py +4 -22
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +43 -12
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +177 -52
- mindspore/common/_utils.py +9 -1
- mindspore/common/api.py +338 -208
- mindspore/common/dtype.py +108 -57
- mindspore/common/dump.py +11 -16
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/generator.py +2 -3
- mindspore/common/hook_handle.py +33 -5
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +27 -29
- mindspore/common/recompute.py +5 -7
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +84 -133
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +47 -38
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +38 -4
- mindspore/dataset/engine/datasets.py +350 -322
- mindspore/dataset/engine/datasets_user_defined.py +69 -23
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +10 -6
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +5 -4
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +1 -0
- mindspore/include/api/cell.h +65 -5
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +10 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +8 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +8 -3
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +61 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +4 -44
- mindspore/mint/distributed/__init__.py +5 -0
- mindspore/mint/distributed/distributed.py +425 -19
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +163 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +125 -101
- mindspore/mint/nn/layer/normalization.py +11 -25
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/cell.py +488 -620
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +36 -36
- mindspore/nn/layer/basic.py +74 -77
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +86 -85
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +38 -40
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +2 -4
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/lamb.py +1 -3
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +2 -3
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +73 -42
- mindspore/nn/wrap/grad_reducer.py +37 -52
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +7 -7
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +1 -1
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
- mindspore/ops/_vmap/vmap_array_ops.py +6 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +29 -10
- mindspore/ops/auto_generate/gen_extend_func.py +5 -55
- mindspore/ops/auto_generate/gen_ops_def.py +753 -273
- mindspore/ops/auto_generate/gen_ops_prim.py +1687 -958
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +9 -5
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +132 -108
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +17 -100
- mindspore/ops/function/debug_func.py +8 -5
- mindspore/ops/function/grad/grad_func.py +5 -13
- mindspore/ops/function/math_func.py +65 -399
- mindspore/ops/function/nn_func.py +44 -61
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +31 -4
- mindspore/ops/functional.py +2 -3
- mindspore/ops/functional_overload.py +486 -18
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +5 -2
- mindspore/ops/operations/_custom_ops_utils.py +675 -8
- mindspore/ops/operations/_inner_ops.py +14 -18
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +4 -50
- mindspore/ops/operations/comm_ops.py +186 -41
- mindspore/ops/operations/custom_ops.py +244 -175
- mindspore/ops/operations/debug_ops.py +55 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +27 -28
- mindspore/ops/operations/math_ops.py +8 -9
- mindspore/ops/operations/nn_ops.py +6 -7
- mindspore/ops/primitive.py +9 -20
- mindspore/ops/tensor_method.py +52 -11
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +7 -2
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +9 -17
- mindspore/parallel/_cell_wrapper.py +106 -40
- mindspore/parallel/_parallel_serialization.py +4 -3
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +17 -12
- mindspore/parallel/_utils.py +5 -11
- mindspore/parallel/auto_parallel.py +33 -12
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +5 -1
- mindspore/parallel/cluster/process_entity/_api.py +88 -49
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +48 -7
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +7 -6
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
- mindspore/parallel/shard.py +9 -23
- mindspore/parallel/transform_safetensors.py +468 -174
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +3 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +9 -0
- mindspore/profiler/common/profiler_context.py +50 -29
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +239 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +374 -338
- mindspore/profiler/envprofiler.py +42 -12
- mindspore/profiler/experimental_config.py +112 -7
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +30 -20
- mindspore/profiler/profiler.py +218 -154
- mindspore/profiler/profiler_action_controller.py +65 -77
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +8 -6
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +87 -45
- mindspore/runtime/memory.py +22 -30
- mindspore/runtime/thread_bind_core.py +299 -165
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +9 -5
- mindspore/train/amp.py +43 -23
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +4 -14
- mindspore/train/callback/_flops_collector.py +11 -7
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +72 -18
- mindspore/train/data_sink.py +15 -6
- mindspore/train/dataset_helper.py +14 -5
- mindspore/train/model.py +49 -47
- mindspore/train/serialization.py +168 -126
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +162 -78
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +14 -17
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/RECORD +400 -439
- mindspore/_deprecated/jit.py +0 -198
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/common/validator/__init__.py +0 -14
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
@@ -15,6 +15,10 @@
 """cell"""
 from __future__ import absolute_import
 
+__all__ = [
+    "register_cell_buffer_registration_hook",
+]
+
 import inspect
 import os
 import time
@@ -24,7 +28,6 @@ from collections import OrderedDict, namedtuple
 from typing import (
     Dict,
     Optional,
-    Set,
     Callable,
     List,
     Tuple,
@@ -34,36 +37,30 @@ from typing import (
     Mapping
 )
 
+import weakref
 import mindspore as ms
+import mindspore.ops as ops
 from mindspore._checkparam import args_type_check, check_hook_fn
-from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
+from mindspore.common.dynamic_shape._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
-from mindspore.common.
-from mindspore.common.hook_handle import HookHandle
-from mindspore.context import ParallelMode
+from mindspore.common.hook_handle import HookHandle, _update_hook_version
 from mindspore import context
 from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
 from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
-    _no_grad
-from mindspore.common.api import _convert_python_data
+    _no_grad, _get_mutable_flags
+from mindspore.common.api import _convert_python_data
 from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
-from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple
+from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple, _is_parameter_generated
 from mindspore.common.tensor import Tensor
-from mindspore.ops.operations import Cast
 from mindspore.ops.primitive import Primitive
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.parallel.shard import Shard
 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
-from mindspore.common._decorator import deprecated
 from mindspore.common._register_for_recompute import recompute_registry
-
-
-__all__ = [
-    "register_cell_buffer_registration_hook",
-]
+from mindspore.common.jit_config import JitConfig
 
 _global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
 _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
@@ -96,13 +93,11 @@ def register_cell_buffer_registration_hook(hook: Callable[..., None],):
         A handle that can be used to remove the added hook by calling
         `handle.remove()`.
     """
-
-    handle =
-    _global_buffer_registration_hooks[handle.id] = hook
+    handle = HookHandle(_global_buffer_registration_hooks)
+    _global_buffer_registration_hooks[handle.handle_id] = hook
     return handle
 
 
-
 class Cell(Cell_):
     """
     The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
@@ -160,51 +155,59 @@ class Cell(Cell_):
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
                    '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
                    '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
-                   '_cell_backward_pre_hook', '_cell_backward_hook', '
-                   '
-                   '_parameters_forward_hook', '_parameters_backward_hook']
+                   '_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix',
+                   'requires_grad', 'cell_type', '_in_strategy', '_out_strategy']
     total_instance_count = 0
     _buffers: Dict[str, Optional[Tensor]]
-
+    global_cells = weakref.WeakKeyDictionary()
+    _no_auto_lazy_inline = True
+
+    def __new__(class_, *args, **kwargs):
+        # Use class_ to avoid name conflicts with input args and kwargs.
+        this = Cell_.__new__(class_, *args, **kwargs)
+        if Cell._no_auto_lazy_inline:
+            return this
+
+        Cell.global_cells[this] = (class_, args, kwargs)
+        return this
 
     def __init__(self, auto_prefix=True, flags=None):
         Cell_.__init__(self, self._cell_tag)
         Cell.total_instance_count += 1
-
-
-        self._cells = OrderedDict()
+        super().__setattr__("_params", OrderedDict())
+        super().__setattr__("_cells", OrderedDict())
         super().__setattr__("_buffers", {})
-        super().__setattr__("
-        super().__setattr__("
-
-        super().__setattr__("
-        super().__setattr__("
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.phase_cache = dict()
+        super().__setattr__("_params_list", OrderedDict())
+        super().__setattr__("_primitives", OrderedDict())
+
+        super().__setattr__("_lazy_non_persistent_buffers_set", None)
+        super().__setattr__("_lazy_state_dict_hooks", None)
+        super().__setattr__("_lazy_state_dict_pre_hooks", None)
+        super().__setattr__("_lazy_load_state_dict_pre_hooks", None)
+        super().__setattr__("_lazy_load_state_dict_post_hooks", None)
+        super().__setattr__("training", False)
+        super().__setattr__("requires_grad", False)
+        super().__setattr__("is_top_cell", False)
+        super().__setattr__("_param_prefix", '')
+        super().__setattr__("_auto_prefix", auto_prefix)
+        super().__setattr__("_scope", None)
+        super().__setattr__("_phase", 'train')
+        super().__setattr__("_parameter_layout_dict", None)
+        super().__setattr__("_parallel_parameter_name_list", None)
+        super().__setattr__("_parallel_parameter_merge_net_dict", None)
+        super().__setattr__("_create_time", int(time.time() * 1e9))
+        super().__setattr__("arguments_key", "")
+        super().__setattr__("_compile_cache", None)
+        super().__setattr__("_phase_cache", None)
         cells_compile_cache[id(self)] = self.compile_cache
-
-
-
-
-
-
-
+        super().__setattr__("_id", 1)
+        super().__setattr__("_exist_objs", None)
+        super().__setattr__("_exist_names", None)
+        super().__setattr__("_recompute_cell", None)
+        super().__setattr__("mixed_precision_type", None)
+        super().__setattr__("_lazy_construct_sig", None)
+        super().__setattr__("_jit_graph_name", '')
+        super().__setattr__("_compiled", False)
         init_pipeline()
 
         # call gc to release GE session resources used by non-used cell objects
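
Note on the hunk above: 2.7.0 switches Cell.__init__ to seed its internal slots through super().__setattr__ rather than plain assignment, so the heavily overloaded Cell.__setattr__ is not invoked for every bookkeeping attribute while the object is being built. A minimal sketch of that pattern, with simplified, hypothetical names (not MindSpore internals):

    from collections import OrderedDict

    class TinyCell:
        def __init__(self):
            # Seed bookkeeping slots directly; bypass the overridden __setattr__.
            object.__setattr__(self, "_params", OrderedDict())
            object.__setattr__(self, "_cells", OrderedDict())

        def __setattr__(self, name, value):
            # Registration logic we do not want to pay for during __init__.
            if isinstance(value, TinyCell):
                self._cells[name] = value
            object.__setattr__(self, name, value)

    net = TinyCell()
    net.child = TinyCell()      # goes through the registration path
    print(list(net._cells))     # ['child']
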
@@ -214,38 +217,35 @@ class Cell(Cell_):
 
         if flags:
             self.add_flags(**flags)
-
+        super().__setattr__("_bprop_debug", False)
 
         # hook
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self._init_flag = False
-        self._shard_fn = None
-        self.has_bprop = False
+        super().__setattr__("_lazy_forward_pre_hook", None)
+        super().__setattr__("_lazy_forward_hook", None)
+        super().__setattr__("_lazy_backward_pre_hook", None)
+        super().__setattr__("_lazy_backward_hook", None)
+        super().__setattr__("_lazy_forward_pre_hook_with_kwargs", None)
+        super().__setattr__("_lazy_forward_hook_with_kwargs", None)
+        super().__setattr__("_cell_backward_pre_hook", None)
+        super().__setattr__("_cell_backward_hook", None)
+        super().__setattr__("_is_recursion_hook", False)
+
+        super().__setattr__("cell_type", None)
+        super().__setattr__("_has_config_recompute", False)
+        super().__setattr__("_lazy_user_parameters", None)
+        super().__setattr__("_dynamic_shape_inputs", None)
+        super().__setattr__("_has_mutable_args_list", None)
+        super().__setattr__("_jit_config_dict", dict())
+        super().__setattr__("grad_ops_label", False)
+        super().__setattr__("_is_check_and_refresh", False)
+        super().__setattr__("_amp_level", "")
+        super().__setattr__("_init_flag", False)
+        super().__setattr__("_shard_fn", None)
+        super().__setattr__("_in_strategy", None)
+        super().__setattr__("_out_strategy", None)
+        super().__setattr__("has_bprop", False)
         if hasattr(self, "bprop"):
-
+            super().__setattr__("has_bprop", True)
 
     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -255,7 +255,6 @@ class Cell(Cell_):
         base, dict_ = state
         Cell_.__setstate__(self, base)
         self.__dict__ = dict_
-        self._attr_synced = False
 
     def __bool__(self):
         return True
@@ -269,6 +268,112 @@ class Cell(Cell_):
     def create_time(self):
         return self._create_time
 
+    @property
+    def _non_persistent_buffers_set(self):
+        """_non_persistent_buffers_set"""
+        if self._lazy_non_persistent_buffers_set is None:
+            super().__setattr__("_lazy_non_persistent_buffers_set", set())
+        return self._lazy_non_persistent_buffers_set
+
+    @property
+    def _state_dict_hooks(self):
+        """_state_dict_hooks"""
+        if self._lazy_state_dict_hooks is None:
+            super().__setattr__("_lazy_state_dict_hooks", OrderedDict())
+        return self._lazy_state_dict_hooks
+
+    @property
+    def _state_dict_pre_hooks(self):
+        """_state_dict_pre_hooks"""
+        if self._lazy_state_dict_pre_hooks is None:
+            super().__setattr__("_lazy_state_dict_pre_hooks", OrderedDict())
+        return self._lazy_state_dict_pre_hooks
+
+    @property
+    def _load_state_dict_pre_hooks(self):
+        """_load_state_dict_pre_hooks"""
+        if self._lazy_load_state_dict_pre_hooks is None:
+            super().__setattr__("_lazy_load_state_dict_pre_hooks", OrderedDict())
+        return self._lazy_load_state_dict_pre_hooks
+
+    @property
+    def _load_state_dict_post_hooks(self):
+        """_load_state_dict_post_hooks"""
+        if self._lazy_load_state_dict_post_hooks is None:
+            super().__setattr__("_lazy_load_state_dict_post_hooks", OrderedDict())
+        return self._lazy_load_state_dict_post_hooks
+
+    @property
+    def compile_cache(self):
+        """compile_cache"""
+        if self._compile_cache is None:
+            super().__setattr__("_compile_cache", set())
+        return self._compile_cache
+
+    @property
+    def phase_cache(self):
+        """phase_cache"""
+        if self._phase_cache is None:
+            super().__setattr__("_phase_cache", dict())
+        return self._phase_cache
+
+    @property
+    def _forward_pre_hook(self):
+        """_forward_pre_hook"""
+        if self._lazy_forward_pre_hook is None:
+            super().__setattr__("_lazy_forward_pre_hook", OrderedDict())
+        return self._lazy_forward_pre_hook
+
+    @property
+    def _forward_hook(self):
+        """_forward_hook"""
+        if self._lazy_forward_hook is None:
+            super().__setattr__("_lazy_forward_hook", OrderedDict())
+        return self._lazy_forward_hook
+
+    @property
+    def _backward_pre_hook(self):
+        """_backward_pre_hook"""
+        if self._lazy_backward_pre_hook is None:
+            super().__setattr__("_lazy_backward_pre_hook", OrderedDict())
+        return self._lazy_backward_pre_hook
+
+    @property
+    def _backward_hook(self):
+        """_backward_hook"""
+        if self._lazy_backward_hook is None:
+            super().__setattr__("_lazy_backward_hook", OrderedDict())
+        return self._lazy_backward_hook
+
+    @property
+    def _forward_pre_hook_with_kwargs(self):
+        """_backward_hook"""
+        if self._lazy_forward_pre_hook_with_kwargs is None:
+            super().__setattr__("_lazy_forward_pre_hook_with_kwargs", OrderedDict())
+        return self._lazy_forward_pre_hook_with_kwargs
+
+    @property
+    def _forward_hook_with_kwargs(self):
+        """_backward_hook"""
+        if self._lazy_forward_hook_with_kwargs is None:
+            super().__setattr__("_lazy_forward_hook_with_kwargs", OrderedDict())
+        return self._lazy_forward_hook_with_kwargs
+
+    @property
+    def _user_parameters(self):
+        """_user_parameters"""
+        if self._lazy_user_parameters is None:
+            super().__setattr__("_lazy_user_parameters", [])
+        return self._lazy_user_parameters
+
+    @_user_parameters.setter
+    def _user_parameters(self, value):
+        """_user_parameters"""
+        if not isinstance(value, list):
+            raise TypeError(f"For 'Cell', the property '_user_parameters' must be list type, "
+                            f"but got type {type(value)}.")
+        self._lazy_user_parameters = value
+
     @property
     def cell_init_args(self):
         return self._cell_init_args
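
Note on the hunk above: each new _lazy_* sentinel set in __init__ is paired with a property that creates the real container only on first access, so cells that never register hooks never allocate the OrderedDicts. A minimal sketch of the lazy-materialisation pattern, with hypothetical names:

    from collections import OrderedDict

    class LazyHooks:
        def __init__(self):
            object.__setattr__(self, "_lazy_forward_hook", None)  # sentinel only

        @property
        def _forward_hook(self):
            if self._lazy_forward_hook is None:                   # materialise lazily
                object.__setattr__(self, "_lazy_forward_hook", OrderedDict())
            return self._lazy_forward_hook

    obj = LazyHooks()
    print(obj._lazy_forward_hook)             # None, nothing allocated yet
    obj._forward_hook["h0"] = lambda *a: None
    print(len(obj._forward_hook))             # 1
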
@@ -279,15 +384,21 @@ class Cell(Cell_):
         Get exist parameter names adding by tuple or list of parameter.
         """
         if self._exist_names is None:
-
+            super().__setattr__("_exist_names", set(""))
         return self._exist_names
 
     @property
     def exist_objs(self):
         if self._exist_objs is None:
-
+            super().__setattr__("_exist_objs", set())
         return self._exist_objs
 
+    @property
+    def _construct_sig(self):
+        if self._lazy_construct_sig is None:
+            super().__setattr__("_lazy_construct_sig", inspect.signature(self.construct))
+        return self._lazy_construct_sig
+
     @property
     def param_prefix(self):
         """
@@ -319,6 +430,13 @@ class Cell(Cell_):
         """
         return self._bprop_debug
 
+    @property
+    def compiled(self):
+        """
+        Get whether `Cell` is compiled in graph mode.
+        """
+        return self._compiled
+
     @bprop_debug.setter
     def bprop_debug(self, value):
         """
@@ -381,6 +499,8 @@ class Cell(Cell_):
         `parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and
         distributed operator information.
         """
+        if self._parameter_layout_dict is None:
+            super().__setattr__("_parameter_layout_dict", {})
         return self._parameter_layout_dict
 
     @property
@@ -396,6 +516,8 @@ class Cell(Cell_):
 
     @property
     def parallel_parameter_name_list(self):
+        if self._parallel_parameter_name_list is None:
+            super().__setattr__("_parallel_parameter_name_list", ())
         return self._parallel_parameter_name_list
 
     @parallel_parameter_name_list.setter
@@ -435,10 +557,23 @@ class Cell(Cell_):
 
     @property
     def pipeline_segment(self):
+        """
+        `pipeline_segment` represents the pipeline segment of current Cell.
+        """
         return self._pipeline_segment
 
     @pipeline_segment.setter
     def pipeline_segment(self, value):
+        """
+        Set the `pipeline_segment` of a Cell. Only effective in zero_bubble_v scheduler.
+
+        Args:
+            value (int): The pipeline segment of a parameter.
+
+        Raises:
+            TypeError: If `value` is not int type or is a bool type.
+            ValueError: If `value` is not a positive integer.
+        """
         if not isinstance(value, int) or isinstance(value, bool):
             raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
                             "must be int type, but got type : {}".format(type(value)))
@@ -450,6 +585,8 @@ class Cell(Cell_):
 
     @property
     def parallel_parameter_merge_net_dict(self):
+        if self._parallel_parameter_merge_net_dict is None:
+            super().__setattr__("_parallel_parameter_merge_net_dict", {})
         return self._parallel_parameter_merge_net_dict
 
     @parallel_parameter_merge_net_dict.setter
@@ -867,6 +1004,7 @@ class Cell(Cell_):
         if hasattr(self, "compile_cache") and self.compile_cache:
             _cell_graph_executor.del_net_res(self, self.compile_cache)
         Cell.total_instance_count -= 1
+        Cell.global_cells.pop(self, None)
 
     def __delattr__(self, name):
         if name in self._params:
@@ -879,47 +1017,15 @@ class Cell(Cell_):
             del self._params_list[name]
         else:
             object.__delattr__(self, name)
-        self._attr_synced = False
-
-    def _cast_mixed_precision_inputs(self, inputs, dst_type):
-        """Cast input for mixed precision"""
-        res = list()
-        for item in inputs:
-            if isinstance(item, tuple):
-                res.append(self._cast_mixed_precision_inputs(item, dst_type))
-            elif isinstance(item, float):
-                res.append(self.cast(item, dst_type))
-            elif hasattr(item, "dtype") and item.dtype in \
-                    {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
-                res.append(self.cast(item, dst_type))
-            else:
-                res.append(item)
-        return tuple(res)
 
     def cast_inputs(self, inputs, dst_type):
         """
         Cast inputs to specified type.
 
-
-
-            dst_type (mindspore.dtype): The specified data type.
-
-        returns:
-            tuple[Tensor], the result with destination data type.
+        .. warning::
+            This interface will be deprecated in future versions.
         """
-
-        for item in inputs:
-            if isinstance(item, tuple):
-                res.append(self.cast_inputs(item, dst_type))
-            else:
-                res.append(self.cast(item, dst_type))
-        return tuple(res)
-
-    def _do_parameter_broadcast(self):
-        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
-            if not self.parameter_broadcast_done:
-                _pynative_executor.parameter_broadcast(self, self.phase)
-                self.parameter_broadcast_done = True
+        logger.warning(f"'cast_inputs' will be deprecated in future versions.")
 
     def run_construct(self, cast_inputs, kwargs):
         """
@@ -940,30 +1046,34 @@ class Cell(Cell_):
         output = self._run_construct(cast_inputs, kwargs)
         return output
 
-    def _run_construct(self, *
+    def _run_construct(self, *args, **kwargs):
        """Run the construct function"""
        if self._forward_pre_hook:
-
+            args, kwargs = self._run_forward_pre_hook(args, kwargs)
+
+        if self._backward_hook:
+            args = self._cell_backward_hook(args)
 
        if self._shard_fn is not None:
-            output = self._shard_fn(*
+            output = self._shard_fn(*args, **kwargs)
        elif _pynative_executor.requires_grad():
-            if self.
-                output = self.
-            elif self._recompute_cell is not None:
-                output = self._recompute_cell(*inputs, **kwargs)
+            if self._recompute_cell is not None:
+                output = self._recompute_cell(*args, **kwargs)
            elif self.has_bprop:
-                output = self._call_custom_bprop(*
+                output = self._call_custom_bprop(*args, **kwargs)
            else:
-                output = self.construct(*
+                output = self.construct(*args, **kwargs)
        else:
-            output = self.construct(*
+            output = self.construct(*args, **kwargs)
 
        if self._forward_hook:
-            output = self._run_forward_hook(
+            output = self._run_forward_hook(args, kwargs, output)
+
+        if self._backward_hook:
+            output = self._cell_backward_hook(output)
 
        if self._backward_pre_hook:
-            output = self.
+            output = self._cell_backward_pre_hook(output)
 
        return output
 
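
Note on the hunk above: the rewritten _run_construct threads *args/**kwargs through the hook chain -- forward pre-hooks may rewrite the inputs, the backward hook wraps inputs and output, and forward hooks may replace the output. A simplified, self-contained sketch of that dispatch order (illustrative only, not the MindSpore hook API):

    class MiniCell:
        def __init__(self):
            self.forward_pre_hooks = []  # hook(cell, args, kwargs) -> (args, kwargs) or None
            self.forward_hooks = []      # hook(cell, args, kwargs, output) -> output or None

        def construct(self, x):
            return x * 2

        def __call__(self, *args, **kwargs):
            for hook in self.forward_pre_hooks:
                res = hook(self, args, kwargs)
                if res is not None:
                    args, kwargs = res
            output = self.construct(*args, **kwargs)
            for hook in self.forward_hooks:
                res = hook(self, args, kwargs, output)
                if res is not None:
                    output = res
            return output

    cell = MiniCell()
    cell.forward_pre_hooks.append(lambda c, a, k: ((a[0] + 1,), k))  # rewrite input
    cell.forward_hooks.append(lambda c, a, k, out: out - 3)          # replace output
    print(cell(4))  # ((4 + 1) * 2) - 3 = 7
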
@@ -998,22 +1108,6 @@ class Cell(Cell_):
                             f"{default_args} default argument, total {positional_args + default_args}, "
                             f"but got {len(args)}.")
 
-    def _hook_fn_registered(self):
-        '''Hook function in graph mode'''
-        # Check super().__init__() in graph mode.
-        try:
-            if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
-                return True
-        except AttributeError as e:
-            raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
-                                 f"Please use 'super().__init__()'.") from e
-        if not self._is_recursion_hook:
-            self._is_recursion_hook = True
-            for cell in self.cells():
-                if cell._hook_fn_registered():
-                    return True
-        return False
-
     def _get_prims_recursively(self):
         all_prims = list()
         for _, value in self._primitives.items():
@@ -1039,9 +1133,6 @@ class Cell(Cell_):
        >>> net = nn.Dense(3, 4)
        >>> net.set_data_parallel()
        """
-        if context._get_mode() == context.PYNATIVE_MODE:
-            raise ValueError("set_data_parallel: does not support PyNative mode.")
-
        all_prims = self._get_prims_recursively()
        for prim in all_prims:
            prim.add_prim_attr("strategy_gen_mode", "data_parallel")
@@ -1120,8 +1211,6 @@ class Cell(Cell_):
        ...         out = self.blocks[i](out)
        ...     return out
        """
-        if context._get_mode() == context.PYNATIVE_MODE:
-            raise ValueError("The Cell offload does not support PyNative mode now.")
        if isinstance(backward_prefetch, str):
            Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
        else:
@@ -1129,11 +1218,10 @@ class Cell(Cell_):
        for prim in self._get_prims_recursively():
            prim._offload(backward_prefetch=backward_prefetch)
 
-    def shard(self, in_strategy, out_strategy=None, parameter_plan=None
+    def shard(self, in_strategy, out_strategy=None, parameter_plan=None):
        """
        Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
-        generated by sharding propagation. In
-        execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
+        generated by sharding propagation. In Graph mode, use this method to specify distribution strategy for a Cell,
        strategy for others will be set by sharding propagation.
        in_strategy and out_strategy define the input and output layout respectively.
        in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
@@ -1145,11 +1233,14 @@ class Cell(Cell_):
          In other parallel modes, strategies set here will be ignored.
        - If the input contain Parameter, its strategy should be set in `in_strategy`.
 
+        .. warning::
+            The method is currently not supported in PyNative mode.
+
        Args:
            in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
                defines the layout of the corresponding input.
            out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
-
+                Default: ``None`` .
            parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
                defines the layout of the parameter like "param_name: layout".
                The key is a parameter name of type 'str'.
@@ -1157,14 +1248,6 @@ class Cell(Cell_):
                If the parameter name is incorrect or the corresponding parameter
                has been set, the parameter setting will be ignored.
                Default: ``None`` .
-            device (str): Select a certain device target. It is not in use right now.
-                Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
-            level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
-                over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
-                use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .
-
-        Returns:
-            Function, return the cell construct function that will be executed under auto parallel process.
 
        Examples:
            >>> import mindspore.nn as nn
@@ -1182,40 +1265,34 @@ class Cell(Cell_):
            ...     def __init__(self):
            ...         self.block1 = Block()
            ...         self.block2 = Block()
-            ...         self.
-            ...             parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
+            ...         self.block2.shard(in_strategy=((2, 1),), parameter_plan={'self.block2.dense1.weight': (4, 1)})
            ...     def construct(self, x):
            ...         x = self.block1(x)
-            ...         x = self.
+            ...         x = self.block2(x)
            ...         return x
        """
        if ms.communication.management.get_group_size() == 1:
-            return
-        shard_fn = Shard()
-        fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
-        self._shard_fn = fn
-        return fn
-
-    def auto_cast_inputs(self, inputs):
-        """
-        Auto cast inputs in mixed precision scenarios.
-
-        Args:
-            inputs (tuple): the inputs of construct.
-
-        Returns:
-            Tuple, the inputs after data type cast.
-        """
-        msg = f"'auto_cast_inputs' is deprecated from version 2.0 and will be removed in a future version."
-        logger.warning(msg)
-        cast_inputs = inputs
-        mixed_type = self.get_mixed_precision_type()
-        if mixed_type == MixedPrecisionType.FP16:
-            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
-        if mixed_type == MixedPrecisionType.FP32:
-            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
+            return
 
-
+        shard_fn = Shard()
+        self._shard_fn = shard_fn(self, in_strategy, out_strategy, parameter_plan)
+
+        if self._in_strategy is not None:  # pylint: disable=E0203
+            msg = (
+                "For '%s', 'Shard' has been configured more than once. "
+                "The existing in_strategy is %s and the existing out_strategy is %s. "
+                "The new in_strategy %s and out_strategy %s may not take effect. "
+                "It is recommended to configure 'Shard' only once."
+            ) % (
+                self._cell_tag,
+                self._in_strategy,  # pylint: disable=E0203
+                self._out_strategy,  # pylint: disable=E0203
+                shard_fn.in_strategy,
+                shard_fn.out_strategy,
+            )
+            logger.warning(msg)
+        self._in_strategy = shard_fn.in_strategy
+        self._out_strategy = shard_fn.out_strategy
 
     def _init_check(self):
        for param in self.get_parameters(expand=False):
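
Note on the hunk above: shard() drops the unused device/level parameters and now warns when a Cell is sharded more than once, while still recording the latest in_strategy/out_strategy. A minimal sketch of that warn-and-record guard, with hypothetical names:

    import logging

    logger = logging.getLogger("shard_demo")

    class ShardedThing:
        def __init__(self):
            self.in_strategy = None

        def shard(self, in_strategy):
            # Warn on reconfiguration, but keep the newest value on record.
            if self.in_strategy is not None:
                logger.warning("'Shard' configured more than once; the new "
                               "in_strategy %s may not take effect.", in_strategy)
            self.in_strategy = in_strategy

    t = ShardedThing()
    t.shard(((2, 1),))
    t.shard(((4, 1),))      # triggers the warning
    print(t.in_strategy)    # ((4, 1),)
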
@@ -1224,15 +1301,25 @@ class Cell(Cell_):
|
|
|
1224
1301
|
self._init_flag = True
|
|
1225
1302
|
|
|
1226
1303
|
def _self_check(self):
|
|
1227
|
-
|
|
1228
|
-
self.
|
|
1229
|
-
|
|
1304
|
+
try:
|
|
1305
|
+
if not self._is_check_and_refresh: # pylint: disable=E0203
|
|
1306
|
+
self.check_names_and_refresh_name()
|
|
1307
|
+
self._is_check_and_refresh = True
|
|
1308
|
+
except AttributeError as e:
|
|
1309
|
+
raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
|
|
1310
|
+
f"Please use 'super().__init__()'.") from e
|
|
1230
1311
|
|
|
1231
1312
|
def _predict(self, *args, **kwargs):
|
|
1313
|
+
'''Graph executor for predict'''
|
|
1232
1314
|
if not hasattr(self, "phase"):
|
|
1233
1315
|
return False, None
|
|
1234
1316
|
if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
|
|
1235
|
-
new_args =
|
|
1317
|
+
new_args = _get_args_for_run(self, args, kwargs, self._has_mutable_args_list, True)
|
|
1318
|
+
if self.jit_config_dict:
|
|
1319
|
+
jit_config_dict = self.jit_config_dict
|
|
1320
|
+
else:
|
|
1321
|
+
jit_config_dict = JitConfig().jit_config_dict
|
|
1322
|
+
_cell_graph_executor._graph_executor.set_jit_config(jit_config_dict)
|
|
1236
1323
|
res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
|
|
1237
1324
|
res = _convert_python_data(res)
|
|
1238
1325
|
return True, res
|
|
@@ -1241,8 +1328,9 @@ class Cell(Cell_):
|
|
|
1241
1328
|
def __call__(self, *args, **kwargs):
|
|
1242
1329
|
# Run in Graph mode.
|
|
1243
1330
|
if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
|
|
1331
|
+
self._compiled = True
|
|
1244
1332
|
if kwargs:
|
|
1245
|
-
bound_arguments = self.
|
|
1333
|
+
bound_arguments = self._construct_sig.bind(*args, **kwargs)
|
|
1246
1334
|
bound_arguments.apply_defaults()
|
|
1247
1335
|
args = bound_arguments.args
|
|
1248
1336
|
kwargs = bound_arguments.kwargs
|
|
@@ -1251,11 +1339,8 @@ class Cell(Cell_):
|
|
|
1251
1339
|
if predict_compiled:
|
|
1252
1340
|
return res
|
|
1253
1341
|
self._check_construct_args(*args)
|
|
1254
|
-
|
|
1255
|
-
if self._hook_fn_registered():
|
|
1256
|
-
logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
|
|
1257
|
-
f"function, please use context.set_context to set pynative mode.")
|
|
1258
1342
|
self._self_check()
|
|
1343
|
+
self.__compile_cell_hook__ = True
|
|
1259
1344
|
out = self.compile_and_run(*args, **kwargs)
|
|
1260
1345
|
return out
|
|
1261
1346
|
|
|
@@ -1324,37 +1409,12 @@ class Cell(Cell_):
         """
         with _no_grad():
             output = self.construct(*args, **kwargs)
-        _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
-        return output
+        return _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)

     def _add_attr(self, name, value):
         if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
             super(Cell, self)._add_attr(name, value)

-    def _sync_attr_for_compile(self):
-        """Sync the attr to c++ object."""
-        if self._attr_synced:
-            return
-        cells = self.__dict__.get('_cells')
-        for key in cells:
-            cell = cells[key]
-            cell._sync_attr_for_compile()
-            self._add_attr(key, cell)
-        params = self.__dict__.get('_params')
-        for key in params:
-            if '.' in key:
-                continue
-            param = params[key]
-            self._add_attr(key, param)
-        params_list = self.__dict__.get('_params_list')
-        for key in params_list:
-            params_list_item = params_list[key]
-            self._add_attr(key, params_list_item)
-        for key in self.__dict__:
-            value = self.__dict__[key]
-            self._add_attr(key, value)
-        self._attr_synced = True
-
     def _set_attr_for_param_or_param_tuple(self, name, value):
         """Set attr for param and tensor."""
         if isinstance(value, Parameter):
@@ -1369,27 +1429,16 @@ class Cell(Cell_):
                 # If there are multiple identical objects, their names only check once.
                 continue
             exist_objs.add(item)
-            if item.name
-
-                    "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
-                item.name = item.name + "$" + str(self._id)
+            if _is_parameter_generated(item.name):
+                item.name = "Parameter$" + str(self._id)
                 self._id += 1
-            self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
             if item.name in exist_names:
                 raise ValueError("The value {} , its name '{}' already exists. "
                                  "Please set a unique name for the parameter.".format(value, item.name))
             exist_names.add(item.name)
+            self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)

-
-        if name in self.__dict__:
-            del self.__dict__[name]
-        params = self.__dict__.get('_params')
-        if name in params:
-            del params[name]
-        params_list = self.__dict__.get('_params_list')
-        params_list[name] = value
-        else:
-            object.__setattr__(self, name, value)
+        object.__setattr__(self, name, value)

     def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
         """Set attr for parameter in list or tuple."""
@@ -1398,9 +1447,6 @@ class Cell(Cell_):
                 # If there are multiple identical objects, their names only check once.
                 continue
             self.exist_objs.add(item)
-            if item.name == PARAMETER_NAME_DEFAULT:
-                item.name = item.name + "$" + str(self._id)
-                self._id += 1
             if item.name in self.exist_names:
                 raise ValueError(f"The value {value} , its name '{item.name}' already exists. "
                                  "Please set a unique name for the parameter.")
@@ -1513,24 +1559,6 @@ class Cell(Cell_):
         main_str += ")"
         return main_str

-    def load_parameter_slice(self, params):
-        """
-        Replace parameters with sliced tensors by parallel strategies.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'load_parameter_slice' function is deprecated.")
-
-    def set_parallel_input_with_inputs(self, *inputs):
-        """
-        Slice inputs tensors by parallel strategies.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'set_parallel_input_with_inputs' function is deprecated.")
-
     def set_inputs(self, *inputs, **kwargs):
         """
         Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When
@@ -1589,8 +1617,6 @@ class Cell(Cell_):
                 _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
             else:
                 self._check_construct_args(*inputs)
-                # TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
-                # which means that incremental mode is lacking dynamic input.
         else:
             self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)

@@ -1665,7 +1691,6 @@ class Cell(Cell_):
             _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs")
             self._check_parameter_consistency(compile_args, args)
             Validator.check_symbolic_shape(compile_args, args)
-            self.saved_dynamic_shape = compile_args
             return compile_args
         return args

@@ -1678,8 +1703,9 @@ class Cell(Cell_):
             kwargs (dict): Kwargs of the Cell object.
         """
         _init_auto_parallel_context(self)
-
-
+        compile_args = self._get_compile_args(args)
+        self._has_mutable_args_list = _get_mutable_flags(compile_args)
+        _cell_graph_executor.compile(self, *compile_args, phase=self.phase,
                                      jit_config_dict=self._jit_config_dict, **kwargs)
         _clear_auto_parallel_context(self)

@@ -1698,25 +1724,14 @@ class Cell(Cell_):
             Object, the result of executing.
         """
         self.compile(*args, **kwargs)
-        self.
-
+        new_args = _get_args_for_run(self, args, kwargs, self._has_mutable_args_list, False)
+        if self.jit_config_dict:
+            jit_config_dict = self.jit_config_dict
+        else:
+            jit_config_dict = JitConfig().jit_config_dict
+        _cell_graph_executor._graph_executor.set_jit_config(jit_config_dict)
         return _cell_graph_executor(self, *new_args, phase=self.phase)

-    def auto_parallel_compile_and_run(self):
-        """
-        Whether or not to execute compile and run in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'auto_parallel_compile_and_run' function is deprecated.")
-
-    def exec_checkpoint_graph(self):
-        """Executes GE saving checkpoint graph operation."""
-        logger.warning("'exec_checkpoint_graph' function is deprecated.")
-        self.add_flags(ge_sync_data=True)
-        _cell_graph_executor(self, phase='save')
-
     def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
         """
         Adds a parameter to the current cell.
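The `compile_and_run` hunk above now pushes the cell's `jit_config_dict` (or a default `JitConfig()`) into the graph executor before dispatching the compiled graph. A minimal usage sketch, not from the package; the `jit_level` value is illustrative:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    ms.set_context(mode=ms.GRAPH_MODE)

    net = nn.Dense(3, 4)
    # If no jit config is attached, the new code falls back to JitConfig() defaults.
    net.set_jit_config(ms.JitConfig(jit_level="O0"))    # illustrative option value
    out = net(Tensor(np.ones((2, 3), np.float32)))      # compile_and_run applies the config
    print(out.shape)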
@@ -1762,35 +1777,10 @@ class Cell(Cell_):
         if not isinstance(param, Parameter) and param is not None:
             raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
                             f"but got {type(param)}.")
-        if isinstance(param, Parameter) and param.name
+        if isinstance(param, Parameter) and _is_parameter_generated(param.name):
             param.name = param_name
         self._params[param_name] = param

-    def cast_param(self, param):
-        """
-        Cast parameter according to auto mix precision level in pynative mode.
-
-        This interface is currently used in the case of auto mix precision and usually needs not to be used explicitly.
-
-        Args:
-            param (Parameter): Parameters, the type of which should be cast.
-
-        Returns:
-            Parameter, the input parameter with type automatically cast.
-        """
-        msg = f"'cast_param' is deprecated from version 2.0 and will be removed in a future version."
-        logger.warning(msg)
-        mixed_type = self.get_mixed_precision_type()
-        if mixed_type != MixedPrecisionType.NOTSET:
-            if mixed_type == MixedPrecisionType.FP32:
-                param.set_cast_dtype(mstype.float32)
-            elif mixed_type == MixedPrecisionType.FP16:
-                param.set_cast_dtype(mstype.float16)
-        elif hasattr(param, "set_cast_dtype"):
-            # retest dtype
-            param.set_cast_dtype()
-        return param
-
     def insert_child_to_cell(self, child_name, child_cell):
         """
         Adds a child cell to the current cell with a given name.
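The hunk above replaces the old default-name comparison in `insert_param_to_cell` with the `_is_parameter_generated` helper, so only auto-generated parameter names are overwritten by the registration key. A hedged sketch of the expected effect, not from the package:

    import numpy as np
    from mindspore import nn, Tensor, Parameter

    net = nn.Cell()
    # Created without an explicit name, so the name is auto-generated and
    # insert_param_to_cell renames it to the given key.
    net.insert_param_to_cell("weight", Parameter(Tensor(np.ones((2, 2), np.float32))))
    print(net.weight.name)          # expected: "weight"

    # A Parameter with an explicit name keeps it.
    net.insert_param_to_cell("bias", Parameter(Tensor(np.zeros(2, np.float32)), name="my_bias"))
    print(net.bias.name)            # expected: "my_bias"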
@@ -1850,27 +1840,10 @@ class Cell(Cell_):
         """
         Remove the redundant parameters.

-
+        .. warning::
+            This interface will be deprecated in future versions.
         """
-
-        for _, cell in cells:
-            params = cell._params.items()
-            for param_name, param in list(params):
-                if param.name not in self.parallel_parameter_name_list:
-                    cell._params.pop(param_name)
-                    logger.info("remove the redundant parameter: %s", param.name)
-                    continue
-            cell_dict = cell.__dict__
-            for key in cell_dict:
-                if isinstance(cell_dict[key], ParameterTuple):
-                    param_tuple = cell_dict[key]
-                    new_param_tuple = []
-                    for param in param_tuple:
-                        if param.name not in self.parallel_parameter_name_list:
-                            logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
-                            continue
-                        new_param_tuple.append(param)
-                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
+        logger.warning(f"'remove_redundant_parameters' will be deprecated in future versions.")

     def _get_cell_parallel_mode(self):
         """Determine whether the current cell is in parallel mode."""
@@ -1926,16 +1899,13 @@ class Cell(Cell_):
         # replace all original usage.
         cells = self.cells_and_names()
         is_parallel_mode = self._get_cell_parallel_mode()
-        is_graph_mode = context.get_context('mode') == context.GRAPH_MODE

         for _, cell in cells:
             params = cell._params.items()
             for param_name, param in params:
-                not_sliced = not param.sliced
-                judgment = not_sliced
                 if param.param_info.is_pipeline_shared_param:
                     continue
-                if
+                if is_parallel_mode and not param.sliced:
                     continue
                 if not auto_parallel_mode:
                     cell._params[param_name] = _updata(param)
@@ -1948,11 +1918,9 @@ class Cell(Cell_):
                     param_tuple = cell_dict[key]
                     new_param_tuple = []
                     for param in param_tuple:
-                        not_sliced = not param.sliced
-                        judgment = not_sliced
                         if param.param_info.is_pipeline_shared_param:
                             continue
-                        if
+                        if is_parallel_mode and not param.sliced:
                             continue
                         if not auto_parallel_mode:
                             new_param_tuple.append(_updata(param))
@@ -2591,15 +2559,6 @@ class Cell(Cell_):
         self.add_flags_recursive(broadcast_flag=mode)
         return self

-    def set_auto_parallel(self):
-        """
-        Set the cell to auto parallel mode.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'set_auto_parallel' function is deprecated.")
-
     def set_jit_config(self, jit_config):
         """
         Set jit config for cell.
@@ -2645,25 +2604,38 @@ class Cell(Cell_):
             raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
         Tensor._flatten_tensors(self.trainable_params(), fusion_size)  # pylint: disable=W0212

-
+    @jit_forbidden_register
+    def register_forward_pre_hook(self, hook_fn, with_kwargs=False):
         """
         Register forward pre hook function for Cell object.

+        The hook will be called before :func:`mindspore.nn.Cell.construct` is invoked.
+
+        The hook function should have one of the following signatures:
+
+        - `hook_fn(cell, args) -> None or new_args` , when `with_kwargs` is ``False`` .
+        - `hook_fn(cell, args, kwargs) -> None or (new_args, new_kwargs)` , when `with_kwargs` is ``True`` .
+
+        where:
+
+        - `cell` (Cell): Cell object on which the hook is registered.
+        - `args` (tuple): Positional arguments passed to the `construct` function.
+        - `kwargs` (dict): Keyword arguments passed to the `construct` function. Only passed to `hook_fn` when
+          `with_kwargs` is ``True`` .
+
         Note:
-            - The `
-            `
-
-
-            -
-            hook_fn
-            - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
-              `construct` function of Cell object. In the pynative mode, if the `register_forward_pre_hook` function is
-              called in the `construct` function of the Cell object, a hook function will be added at each run time of
-              Cell object.
+            - The `hook_fn` can modify the forward inputs by returning new inputs. If `with_kwargs` is ``False`` , a
+              single value (which will be wrapped into a tuple unless already a tuple) or a tuple of args should be
+              returned. If `with_kwargs` is ``True`` , both `args` and `kwargs` should be returned.
+            - In order to prevent running failed when switching to graph mode, it is not recommended to call it in the
+              `construct` function of Cell object.
+            - In the pynative mode, if this method is called inside the `construct` function of the Cell object, a
+              `hook_fn` will be added at each run time of Cell object.

         Args:
             hook_fn (function): Python function. Forward pre hook function.
+            with_kwargs (bool, optional): Specifies whether hook_fn will be passed the kwargs given to the `construct`
+                function. Default: ``False`` .

         Returns:
             A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
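A hedged usage sketch of the new `with_kwargs` option documented in the hunk above, not taken from the package; the toy hook and cell are illustrative:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def pre_hook(cell, args, kwargs):
        # Return None to keep the inputs, or (new_args, new_kwargs) to replace them.
        new_args = (args[0] + 1,)
        return new_args, kwargs

    net = nn.ReLU()
    handle = net.register_forward_pre_hook(pre_hook, with_kwargs=True)
    print(net(Tensor(np.array([-2.0, 0.5], np.float32))))
    handle.remove()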
@@ -2702,16 +2674,41 @@ class Cell(Cell_):
         (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
         value= [ 2.00000000e+00]))
         """
-        if context._get_mode() == context.GRAPH_MODE:
-            return HookHandle()
         check_hook_fn(hook_fn)
-        handle = HookHandle(self._forward_pre_hook)
+        handle = HookHandle(self._forward_pre_hook, extra_dict=self._forward_pre_hook_with_kwargs)
         self._forward_pre_hook[handle.handle_id] = hook_fn
+        if with_kwargs:
+            self._forward_pre_hook_with_kwargs[handle.handle_id] = True
+        _update_hook_version()
         return handle

-
+    @jit_forbidden_register
+    def _run_forward_pre_hook(self, args, kwargs):
         """
         Running forward pre hook function registered on Cell object.
+        """
+        for hook_id, hook_fn in self._forward_pre_hook.items():
+            if hook_id in self._forward_pre_hook_with_kwargs:
+                ret = hook_fn(self, args, kwargs)
+                if ret is not None:
+                    if isinstance(ret, tuple) and len(ret) == 2:
+                        args, kwargs = ret
+                    else:
+                        raise RuntimeError(
+                            "forward pre hook with kwargs must return None or a tuple of (new_args, new_kwargs), "
+                            f"but got {ret}"
+                        )
+            else:
+                ret = hook_fn(self, args)
+                if ret is not None:
+                    if not isinstance(ret, tuple):
+                        ret = (ret,)
+                    args = ret
+        return args, kwargs
+
+    def _jit_forward_pre_hook(self, inputs):
+        """
+        Compile forward pre hook function registered on Cell object.

         Args:
             inputs: The input objects of cell object.
@@ -2731,34 +2728,43 @@ class Cell(Cell_):
         else:
             forward_pre_hook_inputs = ret

-        if
-
-
-
-            raise TypeError(
-                "The forward pre hook return value size is {} not equal to input size {}".format(
-                    len(forward_pre_hook_inputs), len(inputs)))
+        if len(forward_pre_hook_inputs) != len(inputs):
+            raise TypeError(
+                "The forward pre hook return value size is {} not equal to input size {}".format(
+                    len(forward_pre_hook_inputs), len(inputs)))
         return forward_pre_hook_inputs

-
+    @jit_forbidden_register
+    def register_forward_hook(self, hook_fn, with_kwargs=False):
         """
-
+        Register forward hook function for Cell object.
+
+        This hook will be called after :func:`mindspore.nn.Cell.construct` has computed an output.
+
+        The hook function should have one of the following signatures:
+
+        - `hook_fn(cell, args, output) -> None or new_output` , when `with_kwargs` is ``False`` .
+        - `hook_fn(cell, args, kwargs, output) -> None or new_output` , when `with_kwargs` is ``True`` .
+
+        where:
+
+        - `cell` (Cell): Cell object on which the hook is registered.
+        - `args` (tuple): Positional arguments passed to the `construct` function.
+        - `kwargs` (dict): Keyword arguments passed to the `construct` function. Only passed to `hook_fn` when
+          `with_kwargs` is ``True`` .
+        - `output`: Output generated by the `construct` function.

         Note:
-            - The `
-            -
-            `
-
-
-            - It should have the following signature:
-              hook_fn(cell, inputs, output) -> new output object or none.
-            - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
-              `construct` function of Cell object. In the pynative mode, if the `register_forward_hook` function is
-              called in the `construct` function of the Cell object, a hook function will be added at each run time of
-              Cell object.
+            - The `hook_fn` can modify the forward outputs by returning new outputs.
+            - In order to prevent running failed when switching to graph mode, it is not recommended to call it in the
+              `construct` function of Cell object.
+            - In the pynative mode, if this method is called inside the `construct` function of the Cell object, a
+              `hook_fn` will be added at each run time of Cell object.

         Args:
             hook_fn (function): Python function. Forward hook function.
+            with_kwargs (bool, optional): Specifies whether hook_fn will be passed the kwargs given to the `construct`
+                function. Default: ``False`` .

         Returns:
             A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
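A matching hedged sketch for `register_forward_hook` with `with_kwargs=True`, as documented in the hunk above; the hook replaces the cell's output when it returns a value. Not from the package, values illustrative:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def fwd_hook(cell, args, kwargs, output):
        # Returning a new value overrides the output; returning None keeps it.
        return output * 2

    net = nn.Dense(2, 2)
    handle = net.register_forward_hook(fwd_hook, with_kwargs=True)
    out = net(Tensor(np.ones((1, 2), np.float32)))
    handle.remove()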
@@ -2801,16 +2807,17 @@ class Cell(Cell_):
         """
         if self.has_bprop:
             return HookHandle()
-        if context._get_mode() == context.GRAPH_MODE:
-            return HookHandle()
         check_hook_fn(hook_fn)
-        handle = HookHandle(self._forward_hook)
+        handle = HookHandle(self._forward_hook, extra_dict=self._forward_hook_with_kwargs)
         self._forward_hook[handle.handle_id] = hook_fn
+        if with_kwargs:
+            self._forward_hook_with_kwargs[handle.handle_id] = True
+        _update_hook_version()
         return handle

-    def
+    def _jit_forward_hook(self, inputs, output):
         """
-
+        Compile forward hook function registered on Cell object.

         Args:
             inputs: The input objects of Cell object.
@@ -2837,12 +2844,26 @@ class Cell(Cell_):
                     len(forward_hook_output), len(output)))
         return forward_hook_output

+    @jit_forbidden_register
+    def _run_forward_hook(self, args, kwargs, output):
+        """
+        Running forward hook function registered on Cell object.
+        """
+        for hook_id, hook_fn in self._forward_hook.items():
+            if hook_id in self._forward_hook_with_kwargs:
+                ret = hook_fn(self, args, kwargs, output)
+            else:
+                ret = hook_fn(self, args, output)
+            if ret is not None:
+                output = ret
+        return output
+
+    @jit_forbidden_register
     def register_backward_pre_hook(self, hook_fn):
         """
         Register the backward pre hook function.

         Note:
-            - The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
             - The 'hook_fn' must be defined as the following code.
               `cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
             - The 'hook_fn' should have the following signature:
@@ -2891,44 +2912,17 @@ class Cell(Cell_):
         >>> print(output)
         (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
         """
-        if context._get_mode() == context.GRAPH_MODE:
-            return HookHandle()
         check_hook_fn(hook_fn)
-        handle = HookHandle(self._backward_pre_hook)
+        handle = HookHandle(self._backward_pre_hook, extra_dict=None)
         self._backward_pre_hook[handle.handle_id] = hook_fn
-        if self._cell_backward_pre_hook is None:
+        if self._cell_backward_pre_hook is None:  # pylint: disable=E0203
             # Generate a CellBackwardHook prim, and add function for it
             self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
                                                                   self, self._backward_pre_hook)
             self._cell_backward_pre_hook.register_backward_pre_hook()
+        _update_hook_version()
         return handle

-    def _run_backward_pre_hook(self, outputs):
-        """
-        Running backward pre hook function registered on Cell object.
-
-        Args:
-            outputs: The output objects of cell object.
-
-        Returns:
-            - **outputs** - New backward gradient or None.
-
-        Supported Platforms:
-            ``Ascend`` ``GPU`` ``CPU``
-        """
-        if isinstance(outputs, tuple):
-            ret = self._cell_backward_pre_hook(*outputs)
-        else:
-            ret = self._cell_backward_pre_hook(outputs)
-        if isinstance(outputs, tuple):
-            if len(outputs) == 1:
-                ret = (ret,)
-            if len(ret) != len(outputs):
-                raise TypeError(
-                    "The backward pre hook return value size is {} not equal to output size {}".format(
-                        len(ret), len(outputs)))
-        return ret
-
     def get_extra_state(self) -> Any:
         """Return any extra state to include in the cell's state_dict.

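The registration path above now records a hook version via `_update_hook_version()` and no longer short-circuits in graph mode. A hedged sketch of using a backward pre hook with the signature kept in the docstring (`hook_fn(cell, grad_output)`); not from the package, values illustrative:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def backward_pre_hook(cell, grad_output):
        # grad_output is the tuple of gradients flowing into the cell's backward pass;
        # return None to keep it, or a new tuple to replace it.
        print("incoming grads:", [g.shape for g in grad_output])
        return None

    net = nn.Dense(2, 2)
    handle = net.register_backward_pre_hook(backward_pre_hook)
    ms.grad(net)(Tensor(np.ones((1, 2), np.float32)))
    handle.remove()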
@@ -2981,9 +2975,8 @@ class Cell(Cell_):
            A handle that can be used to remove the added hook by calling
            `handle.remove()`.
        """
-
-        handle =
-        self._state_dict_hooks[handle.id] = hook
+        handle = HookHandle(self._state_dict_hooks)
+        self._state_dict_hooks[handle.handle_id] = hook
         return handle

     @jit_forbidden_register
@@ -3029,9 +3022,8 @@ class Cell(Cell_):
        >>> print("extra_param" in net_state_dict)
        True
        """
-
-        handle =
-        self._state_dict_pre_hooks[handle.id] = hook
+        handle = HookHandle(self._state_dict_pre_hooks)
+        self._state_dict_pre_hooks[handle.handle_id] = hook
         return handle

     def _save_to_state_dict(self, destination, prefix, keep_vars):
@@ -3116,7 +3108,6 @@ class Cell(Cell_):
        OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
        ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
        """
-        # TODO: Remove `args` and the parsing logic when BC allows.
         if args:
            # DeprecationWarning is ignored by default
            warnings.warn(
@@ -3169,7 +3160,7 @@ class Cell(Cell_):

        It should have the following signature:

-        hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None
+            hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None

        Args:
            hook (Callable): The hook function before `load_state_dict` is called.
@@ -3178,9 +3169,8 @@ class Cell(Cell_):
            A handle that can be used to remove the added hook by calling
            `handle.remove()`.
        """
-
-        handle =
-        self._load_state_dict_pre_hooks[handle.id] = hook
+        handle = HookHandle(self._load_state_dict_pre_hooks)
+        self._load_state_dict_pre_hooks[handle.handle_id] = hook
         return handle

     @jit_forbidden_register
@@ -3212,9 +3202,8 @@ class Cell(Cell_):
            A handle that can be used to remove the added hook by calling
            `handle.remove()`.
        """
-
-        handle =
-        self._load_state_dict_post_hooks[handle.id] = hook
+        handle = HookHandle(self._load_state_dict_post_hooks)
+        self._load_state_dict_post_hooks[handle.handle_id] = hook
         return handle

     def _load_from_state_dict(
@@ -3450,12 +3439,12 @@ class Cell(Cell_):
         )
         return _IncompatibleKeys(missing_keys, unexpected_keys)

+    @jit_forbidden_register
     def register_backward_hook(self, hook_fn):
         """
         Register the backward hook function.

         Note:
-            - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
             - The 'hook_fn' must be defined as the following code.
               `cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
               the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
@@ -3507,65 +3496,17 @@ class Cell(Cell_):
         >>> print(output)
         (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
         """
-        if context._get_mode() == context.GRAPH_MODE:
-            return HookHandle()
         check_hook_fn(hook_fn)
-        handle = HookHandle(self._backward_hook)
+        handle = HookHandle(self._backward_hook, extra_dict=None)
         self._backward_hook[handle.handle_id] = hook_fn
-        if self._cell_backward_hook is None:
+        if self._cell_backward_hook is None:  # pylint: disable=E0203
             # Generate a CellBackwardHook prim, and add function for it
             self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
                                                               self, self._backward_hook)
             self._cell_backward_hook.register_backward_hook()
+        _update_hook_version()
         return handle

-    def _backward_hook_construct(self, *inputs, **kwargs):
-        """
-        Backward hook construct method to replace original construct method.
-
-        Args:
-            inputs: The input objects of Cell object.
-            kwargs (dict): Dictionary of variable keyword parameters.
-
-        Returns:
-            - **outputs** - The output objects of Cell object.
-
-        Supported Platforms:
-            ``Ascend`` ``GPU`` ``CPU``
-        """
-        # cell_backward_hook has CellBackwardHook op, so keep input args as they are.
-        outputs = self._cell_backward_hook(*inputs)
-        # If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
-        # a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
-        # Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
-        is_need_unwrap = False
-        if isinstance(outputs, tuple) and len(inputs) != 1:
-            is_need_unwrap = True
-
-        if self._recompute_cell is not None:
-            if is_need_unwrap:
-                outputs = self._recompute_cell(*outputs, **kwargs)
-            else:
-                outputs = self._recompute_cell(outputs, **kwargs)
-        elif self.has_bprop:
-            if is_need_unwrap:
-                outputs = self._call_custom_bprop(*outputs, **kwargs)
-            else:
-                outputs = self._call_custom_bprop(outputs, **kwargs)
-        else:
-            if is_need_unwrap:
-                outputs = self.construct(*outputs, **kwargs)
-            else:
-                outputs = self.construct(outputs, **kwargs)
-        if isinstance(outputs, tuple):
-            new_outputs = self._cell_backward_hook(*outputs)
-        else:
-            new_outputs = self._cell_backward_hook(outputs)
-        # if outputs is (X,) and new_outpus is X
-        if isinstance(outputs, tuple) and len(outputs) == 1:
-            new_outputs = (new_outputs,)
-        return new_outputs
-
     def set_param_ps(self, recurse=True, init_in_server=False):
         """
         Set whether the trainable parameters are updated by parameter server and whether the
@@ -3584,12 +3525,6 @@ class Cell(Cell_):
         for param in params:
             param.set_param_ps(init_in_server)

-    @deprecated("1.8", "set_param_fl")
-    def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
-        params = self.parameters_and_names()
-        for param in params:
-            param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
-
     def set_comm_fusion(self, fusion_type, recurse=True):
         """
         Set `comm_fusion` for all the parameters in this cell. Please refer to the description of
@@ -3650,7 +3585,7 @@ class Cell(Cell_):
         """
         Validator.check_bool(mode)
         Validator.check_bool(output_recompute)
-        if not self._has_config_recompute:
+        if not self._has_config_recompute:  # pylint: disable=E0203
             self._has_config_recompute = True
         else:
             logger.info("The recompute interface can be configured only once."
@@ -3693,12 +3628,12 @@ class Cell(Cell_):
                 introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
                 Default: ``False`` .
         """
-        if context.
+        if context._get_mode() == context.PYNATIVE_MODE:
             self._recompute_cell = recompute_registry.get()(self.construct)
         self._recompute()
         if 'mp_comm_recompute' in kwargs.keys():
             self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
-        if 'parallel_optimizer_comm_recompute' in kwargs
+        if 'parallel_optimizer_comm_recompute' in kwargs:
             if kwargs.get('parallel_optimizer_comm_recompute', False):
                 logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
                                "is replaced with zero3.")
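The hunk above tightens the recompute keyword handling (the PyNative-only `_recompute_cell` wrapper, and `parallel_optimizer_comm_recompute` now mapped to zero3 with a warning). Per the key validation retained in the next hunk, only three keyword arguments are accepted; a hedged sketch of calling the public `recompute` interface with them, values illustrative:

    from mindspore import nn

    block = nn.Dense(16, 16)
    # Only these keyword arguments are accepted; any other key raises ValueError.
    block.recompute(mp_comm_recompute=True,
                    parallel_optimizer_comm_recompute=False,
                    recompute_slice_activation=False)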
@@ -3711,38 +3646,6 @@ class Cell(Cell_):
                                  "the key kwargs must be 'mp_comm_recompute', "
                                  "'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)

-    @deprecated("2.3", "infer_param_pipeline_stage")
-    def infer_param_pipeline_stage(self):
-        """
-        Infer pipeline stages of all parameters in the cell.
-
-        Note:
-            - The interface is deprecated from version 2.3 and will be removed in a future version.
-
-        Returns:
-            The params belong to current stage in pipeline parallel.
-
-        Raises:
-            RuntimeError: If there is a parameter does not belong to any stage.
-        """
-        from mindspore.parallel._utils import _get_global_rank, _get_device_num
-        logger.warning(f"This interface may be deleted in the future.")
-        stage_num = context.get_auto_parallel_context("pipeline_stages")
-        device_num = _get_device_num()
-        rank_id = _get_global_rank()
-        per_stage_devices = device_num // stage_num
-        current_stage = rank_id // per_stage_devices
-        params = []
-        for param in self.trainable_params():
-            if not param._pipeline_stage_list:  # pylint: disable=W0212
-                raise RuntimeError("For 'infer_param_pipeline_stage', the parameter {} does not belong to any stage, "
-                                   "please check whether the cell where the param locates has been set "
-                                   "'pipeline_stage'. Otherwise, the parameter should use 'add_pipeline_stage' "
-                                   "to add its stage information".format(param.name))
-            if current_stage in param._pipeline_stage_list:
-                params.append(param)
-        return params
-
     def place(self, role, rank_id):
         """
         Set the label for all operators in this cell.
@@ -3772,19 +3675,6 @@ class Cell(Cell_):
         for op in all_ops:
             op.place(role, rank_id)

-    def _mixed_precision_cast(self, inputs):
-        mixed_type = self.get_mixed_precision_type()
-        if mixed_type == MixedPrecisionType.NOTSET:
-            return inputs
-        if mixed_type == MixedPrecisionType.FP16:
-            cast_type = mstype.float16
-        elif mixed_type == MixedPrecisionType.BF16:
-            cast_type = mstype.bfloat16
-        else:
-            cast_type = mstype.float32
-        cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
-        return cast_inputs
-
     def _get_attr_from_cell(self, network):
         if not isinstance(network, Cell):
             return
@@ -3793,92 +3683,70 @@ class Cell(Cell_):
         if hasattr(network, "_amp_level"):
             self._amp_level = getattr(network, "_amp_level")

-    def
+    def _set_jit_graph_name(self, key):
+        """
+        Set jit graph name.
         """
-
+        self._jit_graph_name = key

-
-
+    def _jit_backward_pre_hook(self, grad_output):
+        new_grad_output = grad_output
+        if not isinstance(grad_output, tuple):
+            new_grad_output = (grad_output,)

-
-
-
-
-
-
-
-
-
-            - The `backward_hook` should have the following signature:
-              backward_hook(parameters) -> New gradients.
+        for fn in self._backward_pre_hook.values():
+            ret = fn(self, new_grad_output)
+            if ret is not None:
+                if not isinstance(ret, tuple):
+                    output = (ret,)
+                else:
+                    output = ret
+            else:
+                output = ops.Depend()(new_grad_output, ret)
+            new_grad_output = output

-
-
-
-
+        if not isinstance(grad_output, tuple):
+            if len(new_grad_output) == 1:
+                return new_grad_output[0]
+            raise TypeError(
+                "The backward pre hook return value size is {} not equal to input size 1".format(
+                    len(new_grad_output)))

-
-
+        if len(new_grad_output) != len(grad_output):
+            raise TypeError(
+                "The backward pre hook return value size is {} not equal to input size {}".format(
+                    len(new_grad_output), len(grad_output)))

-
-            RuntimeError: If the `forward_hook` or `backward_hook ` has unspoorted syntax under GRAPH MODE.
-            TypeError: If the `forward_hook` or `backward_hook` is not defined as required.
+        return new_grad_output

-
-
+    def _jit_backward_hook(self, grad_input, grad_output):
+        backward_hook_input = grad_input
+        backward_hook_output = grad_output
+        if not isinstance(grad_input, tuple):
+            backward_hook_input = (grad_input,)
+        if not isinstance(grad_output, tuple):
+            backward_hook_output = (grad_output,)

-
-
-
-
-
-
-
-
-
-
-        >>> def gradient_hook(gradients):
-        ...     print("--- enter gradient hook ---")
-        ...     outs = []
-        ...     for name, gradient in gradients:
-        ...         print(name, gradient)
-        ...         outs.append(gradient * 2) # double gradient
-        ...     print("--- leave gradient hook ---")
-        ...     return outs
-        ...
-        >>> class Net(nn.Cell):
-        ...     def __init__(self)
-        ...         super(Net, self).__init__()
-        ...         self.w = Parameter(Tensor(np.array([3.0], np.float32)), name='w')
-        ...     def construct(self, x):
-        ...         return self.w * x
-        ...
-        >>> grad = ops.GradOperation(get_by_list=True)
-        >>> net = Net()
-        >>> net._register_parameters_hook(forward_hook=parameter_hook, backward_hook=gradient_hook)
-        >>> x = Tensor(np.array([4.0]).astype(np.float32))
-        >>> output = grad(net, net.trainable_params())(x)
-        --- enter parameter hook ---
-        w
-        Tensor(shape=[1], dtype=Float32, value=[ 3.00000000e+00])
-        --- leave parameter hook ---
-        --- enter gradient hook ---
-        w
-        Tensor(shape=[1], dtype=Float32, value=[ 4.00000000e+00])
-        --- leave gradient hook ---
-        >>> print("doubled grad: ", output)
-        doubled grad: (Tensor(shape=[1], dtype=Float32, value=[ 8.00000000e+00]),)
-        """
-        if not all:
-            self._parameters_forward_hook = forward_hook
-            self._parameters_backward_hook = backward_hook
-        else:
-            for _, cell in self.cells_and_names():
-                cell._parameters_forward_hook = forward_hook
-                cell._parameters_backward_hook = backward_hook
+        for fn in self._backward_hook.values():
+            ret = fn(self, backward_hook_input, backward_hook_output)
+            if ret is not None:
+                if not isinstance(ret, tuple):
+                    output = (ret,)
+                else:
+                    output = ret
+            else:
+                output = ops.Depend()(backward_hook_input, ret)
+
+            backward_hook_input = output

+        if not isinstance(grad_input, tuple):
+            return backward_hook_input[0]
+
+        if len(backward_hook_input) != len(grad_input):
+            raise TypeError(
+                "The backward hook return value size is {} not equal to input size {}".format(
+                    len(backward_hook_input), len(grad_input)))
+        return backward_hook_input

 class GraphCell(Cell):
     """