mindspore-2.6.0rc1-cp310-cp310-win_amd64.whl → mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +1 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +40 -9
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parse/__init__.py +2 -2
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -1
- mindspore/_extends/parse/parser.py +37 -62
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +43 -13
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +27 -2
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +77 -16
- mindspore/common/api.py +238 -113
- mindspore/common/dtype.py +21 -11
- mindspore/common/dump.py +10 -15
- mindspore/common/generator.py +5 -3
- mindspore/common/hook_handle.py +11 -2
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/parameter.py +26 -12
- mindspore/common/recompute.py +3 -3
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +81 -81
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +58 -40
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/datasets.py +20 -7
- mindspore/dataset/engine/datasets_user_defined.py +33 -3
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +7 -3
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -7
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +27 -7
- mindspore/include/api/cell.h +37 -1
- mindspore/include/api/delegate.h +10 -0
- mindspore/include/api/model.h +3 -0
- mindspore/include/api/types.h +2 -2
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +60 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +6 -46
- mindspore/mint/distributed/__init__.py +1 -0
- mindspore/mint/distributed/distributed.py +212 -9
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +164 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +137 -101
- mindspore/mint/nn/layer/normalization.py +8 -22
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/cell.py +328 -502
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +32 -34
- mindspore/nn/layer/basic.py +67 -64
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +117 -110
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +37 -39
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +3 -3
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -0
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +34 -37
- mindspore/nn/wrap/grad_reducer.py +37 -37
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +5 -5
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_vmap/vmap_array_ops.py +31 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +42 -11
- mindspore/ops/auto_generate/gen_extend_func.py +23 -141
- mindspore/ops/auto_generate/gen_ops_def.py +727 -321
- mindspore/ops/auto_generate/gen_ops_prim.py +1721 -984
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +8 -4
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +3 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +9 -96
- mindspore/ops/function/debug_func.py +4 -3
- mindspore/ops/function/grad/grad_func.py +1 -1
- mindspore/ops/function/math_func.py +33 -540
- mindspore/ops/function/nn_func.py +28 -74
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +44 -5
- mindspore/ops/function/vmap_func.py +2 -1
- mindspore/ops/functional.py +2 -3
- mindspore/ops/functional_overload.py +571 -6
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +16 -11
- mindspore/ops/operations/_custom_ops_utils.py +689 -34
- mindspore/ops/operations/_inner_ops.py +3 -6
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +2 -2
- mindspore/ops/operations/comm_ops.py +185 -26
- mindspore/ops/operations/custom_ops.py +294 -174
- mindspore/ops/operations/debug_ops.py +59 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +15 -16
- mindspore/ops/operations/math_ops.py +3 -4
- mindspore/ops/operations/nn_ops.py +7 -39
- mindspore/ops/primitive.py +6 -10
- mindspore/ops/tensor_method.py +47 -8
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +8 -3
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +11 -8
- mindspore/parallel/_cell_wrapper.py +113 -45
- mindspore/parallel/_parallel_serialization.py +1 -1
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +13 -8
- mindspore/parallel/auto_parallel.py +14 -7
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +11 -7
- mindspore/parallel/cluster/process_entity/_api.py +84 -48
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +43 -4
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +6 -7
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +9 -10
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -6
- mindspore/parallel/shard.py +3 -4
- mindspore/parallel/transform_safetensors.py +463 -174
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +42 -22
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/profiler_context.py +25 -27
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_op_analyse.py +235 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +305 -314
- mindspore/profiler/envprofiler.py +12 -7
- mindspore/profiler/experimental_config.py +96 -6
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/npu_profiler.py +29 -19
- mindspore/profiler/profiler.py +35 -19
- mindspore/profiler/profiler_action_controller.py +64 -76
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +5 -5
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +60 -45
- mindspore/runtime/memory.py +30 -32
- mindspore/runtime/thread_bind_core.py +298 -164
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +14 -4
- mindspore/train/amp.py +43 -20
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_checkpoint.py +3 -6
- mindspore/train/callback/_flops_collector.py +1 -1
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +97 -16
- mindspore/train/data_sink.py +11 -2
- mindspore/train/dataset_helper.py +9 -0
- mindspore/train/model.py +135 -55
- mindspore/train/serialization.py +133 -111
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +163 -77
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +6 -9
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +5 -4
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +356 -394
- mindspore/_deprecated/jit.py +0 -198
- mindspore/experimental/es/__init__.py +0 -22
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
@@ -15,6 +15,10 @@
 """cell"""
 from __future__ import absolute_import
 
+__all__ = [
+    "register_cell_buffer_registration_hook",
+]
+
 import inspect
 import os
 import time
@@ -24,7 +28,6 @@ from collections import OrderedDict, namedtuple
 from typing import (
     Dict,
     Optional,
-    Set,
     Callable,
     List,
     Tuple,
@@ -34,36 +37,29 @@ from typing import (
     Mapping
 )
 
+import weakref
 import mindspore as ms
 from mindspore._checkparam import args_type_check, check_hook_fn
 from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
-from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
 from mindspore.common.hook_handle import HookHandle
-from mindspore.context import ParallelMode
 from mindspore import context
 from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
 from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
-    _no_grad
-from mindspore.common.api import _convert_python_data
+    _no_grad, _get_mutable_flags
+from mindspore.common.api import _convert_python_data
 from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
-from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple
+from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple, _is_parameter_generated
 from mindspore.common.tensor import Tensor
-from mindspore.ops.operations import Cast
 from mindspore.ops.primitive import Primitive
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.parallel.shard import Shard
 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
-from mindspore.common._decorator import deprecated
 from mindspore.common._register_for_recompute import recompute_registry
-
-
-__all__ = [
-    "register_cell_buffer_registration_hook",
-]
+from mindspore.common.jit_config import JitConfig
 
 _global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
 _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
@@ -102,7 +98,6 @@ def register_cell_buffer_registration_hook(hook: Callable[..., None],):
     return handle
 
 
-
 class Cell(Cell_):
     """
     The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
@@ -160,51 +155,57 @@ class Cell(Cell_):
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
                    '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
                    '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
-                   '_cell_backward_pre_hook', '_cell_backward_hook', '
-                   '_attr_synced', 'pynative', 'requires_grad', 'cell_type',
-                   '_parameters_forward_hook', '_parameters_backward_hook']
+                   '_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix', 'requires_grad', 'cell_type']
     total_instance_count = 0
     _buffers: Dict[str, Optional[Tensor]]
-
+    global_cells = weakref.WeakKeyDictionary()
+    _no_auto_lazy_inline = True
+
+    def __new__(class_, *args, **kwargs):
+        # Use class_ to avoid name conflicts with input args and kwargs.
+        this = Cell_.__new__(class_, *args, **kwargs)
+        if Cell._no_auto_lazy_inline:
+            return this
+
+        Cell.global_cells[this] = (class_, args, kwargs)
+        return this
 
     def __init__(self, auto_prefix=True, flags=None):
         Cell_.__init__(self, self._cell_tag)
         Cell.total_instance_count += 1
-
-
-        self._cells = OrderedDict()
+        super().__setattr__("_params", OrderedDict())
+        super().__setattr__("_cells", OrderedDict())
         super().__setattr__("_buffers", {})
-        super().__setattr__("
-        super().__setattr__("
-
-        super().__setattr__("
-        super().__setattr__("
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.phase_cache = dict()
+        super().__setattr__("_params_list", OrderedDict())
+        super().__setattr__("_primitives", OrderedDict())
+
+        super().__setattr__("_lazy_non_persistent_buffers_set", None)
+        super().__setattr__("_lazy_state_dict_hooks", None)
+        super().__setattr__("_lazy_state_dict_pre_hooks", None)
+        super().__setattr__("_lazy_load_state_dict_pre_hooks", None)
+        super().__setattr__("_lazy_load_state_dict_post_hooks", None)
+        super().__setattr__("training", False)
+        super().__setattr__("requires_grad", False)
+        super().__setattr__("is_top_cell", False)
+        super().__setattr__("_param_prefix", '')
+        super().__setattr__("_auto_prefix", auto_prefix)
+        super().__setattr__("_scope", None)
+        super().__setattr__("_phase", 'train')
+        super().__setattr__("_parameter_layout_dict", None)
+        super().__setattr__("_parallel_parameter_name_list", None)
+        super().__setattr__("_parallel_parameter_merge_net_dict", None)
+        super().__setattr__("_create_time", int(time.time() * 1e9))
+        super().__setattr__("arguments_key", "")
+        super().__setattr__("_compile_cache", None)
+        super().__setattr__("_phase_cache", None)
         cells_compile_cache[id(self)] = self.compile_cache
-
-
-
-
-
-
-
+        super().__setattr__("_id", 1)
+        super().__setattr__("_exist_objs", None)
+        super().__setattr__("_exist_names", None)
+        super().__setattr__("_recompute_cell", None)
+        super().__setattr__("mixed_precision_type", None)
+        super().__setattr__("_lazy_construct_sig", None)
+        super().__setattr__("_jit_graph_name", '')
         init_pipeline()
 
         # call gc to release GE session resources used by non-used cell objects
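The `__new__` added above registers every newly created Cell, together with the class and constructor arguments it was built from, in a class-level `weakref.WeakKeyDictionary` (`global_cells`) once auto lazy inline is enabled. Below is a minimal, self-contained sketch of that tracking pattern; the `Widget` class and `_registry` name are illustrative only, not MindSpore code.

import weakref

class Widget:
    # Class-level registry; an entry disappears automatically when its
    # instance is garbage collected, so the registry never keeps objects alive.
    _registry = weakref.WeakKeyDictionary()

    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        # Remember how the instance was built so it can be re-created later.
        Widget._registry[obj] = (cls, args, kwargs)
        return obj

    def __init__(self, name, size=1):
        self.name = name
        self.size = size

w = Widget("conv", size=3)
cls, args, kwargs = Widget._registry[w]
rebuilt = cls(*args, **kwargs)   # re-instantiate from the recorded arguments
del w                            # the registry entry for w is dropped automatically

Because the keys are weak references, dropping the last strong reference to an instance also removes its registry entry.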
@@ -214,38 +215,33 @@ class Cell(Cell_):
 
         if flags:
             self.add_flags(**flags)
-
+        super().__setattr__("_bprop_debug", False)
 
         # hook
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self._is_check_and_refresh = False
-        self._amp_level = ""
-        self._init_flag = False
-        self._shard_fn = None
-        self.has_bprop = False
+        super().__setattr__("_lazy_forward_pre_hook", None)
+        super().__setattr__("_lazy_forward_hook", None)
+        super().__setattr__("_lazy_backward_pre_hook", None)
+        super().__setattr__("_lazy_backward_hook", None)
+        super().__setattr__("_lazy_forward_pre_hook_with_kwargs", None)
+        super().__setattr__("_lazy_forward_hook_with_kwargs", None)
+        super().__setattr__("_cell_backward_pre_hook", None)
+        super().__setattr__("_cell_backward_hook", None)
+        super().__setattr__("_is_recursion_hook", False)
+
+        super().__setattr__("cell_type", None)
+        super().__setattr__("_has_config_recompute", False)
+        super().__setattr__("_lazy_user_parameters", None)
+        super().__setattr__("_dynamic_shape_inputs", None)
+        super().__setattr__("_has_mutable_args_list", None)
+        super().__setattr__("_jit_config_dict", dict())
+        super().__setattr__("grad_ops_label", False)
+        super().__setattr__("_is_check_and_refresh", False)
+        super().__setattr__("_amp_level", "")
+        super().__setattr__("_init_flag", False)
+        super().__setattr__("_shard_fn", None)
+        super().__setattr__("has_bprop", False)
         if hasattr(self, "bprop"):
-
+            super().__setattr__("has_bprop", True)
 
     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -255,7 +251,6 @@ class Cell(Cell_):
         base, dict_ = state
         Cell_.__setstate__(self, base)
         self.__dict__ = dict_
-        self._attr_synced = False
 
     def __bool__(self):
         return True
@@ -269,6 +264,112 @@ class Cell(Cell_):
     def create_time(self):
         return self._create_time
 
+    @property
+    def _non_persistent_buffers_set(self):
+        """_non_persistent_buffers_set"""
+        if self._lazy_non_persistent_buffers_set is None:
+            super().__setattr__("_lazy_non_persistent_buffers_set", set())
+        return self._lazy_non_persistent_buffers_set
+
+    @property
+    def _state_dict_hooks(self):
+        """_state_dict_hooks"""
+        if self._lazy_state_dict_hooks is None:
+            super().__setattr__("_lazy_state_dict_hooks", OrderedDict())
+        return self._lazy_state_dict_hooks
+
+    @property
+    def _state_dict_pre_hooks(self):
+        """_state_dict_pre_hooks"""
+        if self._lazy_state_dict_pre_hooks is None:
+            super().__setattr__("_lazy_state_dict_pre_hooks", OrderedDict())
+        return self._lazy_state_dict_pre_hooks
+
+    @property
+    def _load_state_dict_pre_hooks(self):
+        """_load_state_dict_pre_hooks"""
+        if self._lazy_load_state_dict_pre_hooks is None:
+            super().__setattr__("_lazy_load_state_dict_pre_hooks", OrderedDict())
+        return self._lazy_load_state_dict_pre_hooks
+
+    @property
+    def _load_state_dict_post_hooks(self):
+        """_load_state_dict_post_hooks"""
+        if self._lazy_load_state_dict_post_hooks is None:
+            super().__setattr__("_lazy_load_state_dict_post_hooks", OrderedDict())
+        return self._lazy_load_state_dict_post_hooks
+
+    @property
+    def compile_cache(self):
+        """compile_cache"""
+        if self._compile_cache is None:
+            super().__setattr__("_compile_cache", set())
+        return self._compile_cache
+
+    @property
+    def phase_cache(self):
+        """phase_cache"""
+        if self._phase_cache is None:
+            super().__setattr__("_phase_cache", dict())
+        return self._phase_cache
+
+    @property
+    def _forward_pre_hook(self):
+        """_forward_pre_hook"""
+        if self._lazy_forward_pre_hook is None:
+            super().__setattr__("_lazy_forward_pre_hook", OrderedDict())
+        return self._lazy_forward_pre_hook
+
+    @property
+    def _forward_hook(self):
+        """_forward_hook"""
+        if self._lazy_forward_hook is None:
+            super().__setattr__("_lazy_forward_hook", OrderedDict())
+        return self._lazy_forward_hook
+
+    @property
+    def _backward_pre_hook(self):
+        """_backward_pre_hook"""
+        if self._lazy_backward_pre_hook is None:
+            super().__setattr__("_lazy_backward_pre_hook", OrderedDict())
+        return self._lazy_backward_pre_hook
+
+    @property
+    def _backward_hook(self):
+        """_backward_hook"""
+        if self._lazy_backward_hook is None:
+            super().__setattr__("_lazy_backward_hook", OrderedDict())
+        return self._lazy_backward_hook
+
+    @property
+    def _forward_pre_hook_with_kwargs(self):
+        """_backward_hook"""
+        if self._lazy_forward_pre_hook_with_kwargs is None:
+            super().__setattr__("_lazy_forward_pre_hook_with_kwargs", OrderedDict())
+        return self._lazy_forward_pre_hook_with_kwargs
+
+    @property
+    def _forward_hook_with_kwargs(self):
+        """_backward_hook"""
+        if self._lazy_forward_hook_with_kwargs is None:
+            super().__setattr__("_lazy_forward_hook_with_kwargs", OrderedDict())
+        return self._lazy_forward_hook_with_kwargs
+
+    @property
+    def _user_parameters(self):
+        """_user_parameters"""
+        if self._lazy_user_parameters is None:
+            super().__setattr__("_lazy_user_parameters", [])
+        return self._lazy_user_parameters
+
+    @_user_parameters.setter
+    def _user_parameters(self, value):
+        """_user_parameters"""
+        if not isinstance(value, list):
+            raise TypeError(f"For 'Cell', the property '_user_parameters' must be list type, "
+                            f"but got type {type(value)}.")
+        self._lazy_user_parameters = value
+
     @property
     def cell_init_args(self):
         return self._cell_init_args
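The block of properties added above all follows a single idiom: `__init__` stores a cheap `None` sentinel through `object.__setattr__`, and the real container (an `OrderedDict`, `set` or `list`) is only built the first time the attribute is read. A small self-contained sketch of that lazy-initialization pattern, using a hypothetical `LazyHooks` class rather than `Cell` itself:

from collections import OrderedDict

class LazyHooks:
    def __init__(self):
        # Store a cheap sentinel instead of building the dict up front;
        # going through object.__setattr__ skips any custom __setattr__ logic.
        super().__setattr__("_lazy_forward_hooks", None)

    @property
    def forward_hooks(self):
        # Build the real container only on first access.
        if self._lazy_forward_hooks is None:
            super().__setattr__("_lazy_forward_hooks", OrderedDict())
        return self._lazy_forward_hooks

h = LazyHooks()
assert h._lazy_forward_hooks is None          # nothing allocated yet
h.forward_hooks["probe"] = lambda *a: None    # first access creates the dict
assert len(h.forward_hooks) == 1

This keeps cell construction cheap for the common case where most hook and state-dict containers are never used.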
@@ -279,15 +380,21 @@
         Get exist parameter names adding by tuple or list of parameter.
         """
         if self._exist_names is None:
-
+            super().__setattr__("_exist_names", set(""))
         return self._exist_names
 
     @property
     def exist_objs(self):
         if self._exist_objs is None:
-
+            super().__setattr__("_exist_objs", set())
         return self._exist_objs
 
+    @property
+    def _construct_sig(self):
+        if self._lazy_construct_sig is None:
+            super().__setattr__("_lazy_construct_sig", inspect.signature(self.construct))
+        return self._lazy_construct_sig
+
     @property
     def param_prefix(self):
         """
@@ -381,6 +488,8 @@
         `parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and
         distributed operator information.
         """
+        if self._parameter_layout_dict is None:
+            super().__setattr__("_parameter_layout_dict", {})
         return self._parameter_layout_dict
 
     @property
@@ -396,6 +505,8 @@
 
     @property
     def parallel_parameter_name_list(self):
+        if self._parallel_parameter_name_list is None:
+            super().__setattr__("_parallel_parameter_name_list", ())
         return self._parallel_parameter_name_list
 
     @parallel_parameter_name_list.setter
@@ -450,6 +561,8 @@
 
     @property
     def parallel_parameter_merge_net_dict(self):
+        if self._parallel_parameter_merge_net_dict is None:
+            super().__setattr__("_parallel_parameter_merge_net_dict", {})
         return self._parallel_parameter_merge_net_dict
 
     @parallel_parameter_merge_net_dict.setter
@@ -867,6 +980,7 @@
         if hasattr(self, "compile_cache") and self.compile_cache:
            _cell_graph_executor.del_net_res(self, self.compile_cache)
         Cell.total_instance_count -= 1
+        Cell.global_cells.pop(self, None)
 
     def __delattr__(self, name):
         if name in self._params:
@@ -879,47 +993,15 @@
             del self._params_list[name]
         else:
             object.__delattr__(self, name)
-        self._attr_synced = False
-
-    def _cast_mixed_precision_inputs(self, inputs, dst_type):
-        """Cast input for mixed precision"""
-        res = list()
-        for item in inputs:
-            if isinstance(item, tuple):
-                res.append(self._cast_mixed_precision_inputs(item, dst_type))
-            elif isinstance(item, float):
-                res.append(self.cast(item, dst_type))
-            elif hasattr(item, "dtype") and item.dtype in \
-                    {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
-                res.append(self.cast(item, dst_type))
-            else:
-                res.append(item)
-        return tuple(res)
 
     def cast_inputs(self, inputs, dst_type):
         """
         Cast inputs to specified type.
 
-
-
-            dst_type (mindspore.dtype): The specified data type.
-
-        returns:
-            tuple[Tensor], the result with destination data type.
+        .. warning::
+            This interface will be deprecated in future versions.
         """
-
-        for item in inputs:
-            if isinstance(item, tuple):
-                res.append(self.cast_inputs(item, dst_type))
-            else:
-                res.append(self.cast(item, dst_type))
-        return tuple(res)
-
-    def _do_parameter_broadcast(self):
-        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
-            if not self.parameter_broadcast_done:
-                _pynative_executor.parameter_broadcast(self, self.phase)
-                self.parameter_broadcast_done = True
+        logger.warning(f"'cast_inputs' will be deprecated in future versions.")
 
     def run_construct(self, cast_inputs, kwargs):
         """
@@ -940,29 +1022,29 @@
             output = self._run_construct(cast_inputs, kwargs)
         return output
 
-    def _run_construct(self, *
+    def _run_construct(self, *args, **kwargs):
         """Run the construct function"""
         if self._forward_pre_hook:
-
+            args, kwargs = self._run_forward_pre_hook(args, kwargs)
 
         if self._shard_fn is not None:
-            output = self._shard_fn(*
+            output = self._shard_fn(*args, **kwargs)
         elif _pynative_executor.requires_grad():
             if self._backward_hook:
-                output = self._backward_hook_construct(*
+                output = self._backward_hook_construct(*args, **kwargs)
             elif self._recompute_cell is not None:
-                output = self._recompute_cell(*
+                output = self._recompute_cell(*args, **kwargs)
             elif self.has_bprop:
-                output = self._call_custom_bprop(*
+                output = self._call_custom_bprop(*args, **kwargs)
             else:
-                output = self.construct(*
+                output = self.construct(*args, **kwargs)
         else:
-            output = self.construct(*
+            output = self.construct(*args, **kwargs)
 
         if self._forward_hook:
-            output = self._run_forward_hook(
+            output = self._run_forward_hook(args, kwargs, output)
 
-        if self._backward_pre_hook:
+        if self._backward_pre_hook and _pynative_executor.requires_grad():
             output = self._run_backward_pre_hook(output)
 
         return output
@@ -998,6 +1080,7 @@
                              f"{default_args} default argument, total {positional_args + default_args}, "
                             f"but got {len(args)}.")
 
+    # pylint: disable=E0203
     def _hook_fn_registered(self):
         '''Hook function in graph mode'''
         # Check super().__init__() in graph mode.
@@ -1141,9 +1224,9 @@
         The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
 
         Note:
-
-
-              If the input contain Parameter, its strategy should be set in `in_strategy`.
+            - It is valid only in semi auto parallel or auto parallel mode.
+              In other parallel modes, strategies set here will be ignored.
+            - If the input contain Parameter, its strategy should be set in `in_strategy`.
 
         Args:
             in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
@@ -1196,27 +1279,6 @@
         self._shard_fn = fn
         return fn
 
-    def auto_cast_inputs(self, inputs):
-        """
-        Auto cast inputs in mixed precision scenarios.
-
-        Args:
-            inputs (tuple): the inputs of construct.
-
-        Returns:
-            Tuple, the inputs after data type cast.
-        """
-        msg = f"'auto_cast_inputs' is deprecated from version 2.0 and will be removed in a future version."
-        logger.warning(msg)
-        cast_inputs = inputs
-        mixed_type = self.get_mixed_precision_type()
-        if mixed_type == MixedPrecisionType.FP16:
-            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
-        if mixed_type == MixedPrecisionType.FP32:
-            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
-
-        return cast_inputs
-
     def _init_check(self):
         for param in self.get_parameters(expand=False):
             if param.has_init:
@@ -1229,10 +1291,16 @@
         self._is_check_and_refresh = True
 
     def _predict(self, *args, **kwargs):
+        '''Graph executor for predict'''
         if not hasattr(self, "phase"):
             return False, None
         if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
-            new_args =
+            new_args = _get_args_for_run(self, args, kwargs, self._has_mutable_args_list, True)
+            if self.jit_config_dict:
+                jit_config_dict = self.jit_config_dict
+            else:
+                jit_config_dict = JitConfig().jit_config_dict
+            _cell_graph_executor._graph_executor.set_jit_config(jit_config_dict)
             res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
             res = _convert_python_data(res)
             return True, res
@@ -1242,7 +1310,7 @@
         # Run in Graph mode.
         if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
             if kwargs:
-                bound_arguments = self.
+                bound_arguments = self._construct_sig.bind(*args, **kwargs)
                 bound_arguments.apply_defaults()
                 args = bound_arguments.args
                 kwargs = bound_arguments.kwargs
@@ -1324,37 +1392,12 @@
         """
         with _no_grad():
             output = self.construct(*args, **kwargs)
-        _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
-        return output
+        return _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
 
     def _add_attr(self, name, value):
         if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
             super(Cell, self)._add_attr(name, value)
 
-    def _sync_attr_for_compile(self):
-        """Sync the attr to c++ object."""
-        if self._attr_synced:
-            return
-        cells = self.__dict__.get('_cells')
-        for key in cells:
-            cell = cells[key]
-            cell._sync_attr_for_compile()
-            self._add_attr(key, cell)
-        params = self.__dict__.get('_params')
-        for key in params:
-            if '.' in key:
-                continue
-            param = params[key]
-            self._add_attr(key, param)
-        params_list = self.__dict__.get('_params_list')
-        for key in params_list:
-            params_list_item = params_list[key]
-            self._add_attr(key, params_list_item)
-        for key in self.__dict__:
-            value = self.__dict__[key]
-            self._add_attr(key, value)
-        self._attr_synced = True
-
     def _set_attr_for_param_or_param_tuple(self, name, value):
         """Set attr for param and tensor."""
         if isinstance(value, Parameter):
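In the new `__call__` path above, keyword arguments are normalized by binding them against the cached `inspect.signature(self.construct)` (the `_construct_sig` property) before the graph is compiled. The snippet below shows how `Signature.bind` plus `apply_defaults` folds keyword arguments and defaults into a stable positional form; the `construct` function here is a stand-in, not the MindSpore API.

import inspect

def construct(x, scale=1.0, *, offset=0.0):
    return x * scale + offset

sig = inspect.signature(construct)    # computed once and cached per cell
bound = sig.bind(2.0, offset=0.5)     # mixed positional and keyword input
bound.apply_defaults()                # fills in scale=1.0
print(bound.args, bound.kwargs)       # (2.0, 1.0) {'offset': 0.5}
print(construct(*bound.args, **bound.kwargs))   # 2.0 * 1.0 + 0.5 = 2.5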
@@ -1369,16 +1412,14 @@
                 # If there are multiple identical objects, their names only check once.
                 continue
             exist_objs.add(item)
-            if item.name
-
-                    "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
-                item.name = item.name + "$" + str(self._id)
+            if _is_parameter_generated(item.name):
+                item.name = "Parameter$" + str(self._id)
                 self._id += 1
-            self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
             if item.name in exist_names:
                 raise ValueError("The value {} , its name '{}' already exists. "
                                  "Please set a unique name for the parameter.".format(value, item.name))
             exist_names.add(item.name)
+            self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
 
         if context._get_mode() == context.PYNATIVE_MODE:
             if name in self.__dict__:
@@ -1398,9 +1439,6 @@
                 # If there are multiple identical objects, their names only check once.
                 continue
             self.exist_objs.add(item)
-            if item.name == PARAMETER_NAME_DEFAULT:
-                item.name = item.name + "$" + str(self._id)
-                self._id += 1
             if item.name in self.exist_names:
                 raise ValueError(f"The value {value} , its name '{item.name}' already exists. "
                                  "Please set a unique name for the parameter.")
@@ -1513,24 +1551,6 @@
         main_str += ")"
         return main_str
 
-    def load_parameter_slice(self, params):
-        """
-        Replace parameters with sliced tensors by parallel strategies.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'load_parameter_slice' function is deprecated.")
-
-    def set_parallel_input_with_inputs(self, *inputs):
-        """
-        Slice inputs tensors by parallel strategies.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'set_parallel_input_with_inputs' function is deprecated.")
-
     def set_inputs(self, *inputs, **kwargs):
         """
         Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When
@@ -1665,7 +1685,6 @@
             _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs")
             self._check_parameter_consistency(compile_args, args)
             Validator.check_symbolic_shape(compile_args, args)
-            self.saved_dynamic_shape = compile_args
             return compile_args
         return args
 
@@ -1678,8 +1697,9 @@
             kwargs (dict): Kwargs of the Cell object.
         """
         _init_auto_parallel_context(self)
-
-
+        compile_args = self._get_compile_args(args)
+        self._has_mutable_args_list = _get_mutable_flags(compile_args)
+        _cell_graph_executor.compile(self, *compile_args, phase=self.phase,
                                      jit_config_dict=self._jit_config_dict, **kwargs)
         _clear_auto_parallel_context(self)
 
@@ -1698,25 +1718,14 @@
             Object, the result of executing.
         """
         self.compile(*args, **kwargs)
-        self.
-
+        new_args = _get_args_for_run(self, args, kwargs, self._has_mutable_args_list, False)
+        if self.jit_config_dict:
+            jit_config_dict = self.jit_config_dict
+        else:
+            jit_config_dict = JitConfig().jit_config_dict
+        _cell_graph_executor._graph_executor.set_jit_config(jit_config_dict)
         return _cell_graph_executor(self, *new_args, phase=self.phase)
 
-    def auto_parallel_compile_and_run(self):
-        """
-        Whether or not to execute compile and run in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'auto_parallel_compile_and_run' function is deprecated.")
-
-    def exec_checkpoint_graph(self):
-        """Executes GE saving checkpoint graph operation."""
-        logger.warning("'exec_checkpoint_graph' function is deprecated.")
-        self.add_flags(ge_sync_data=True)
-        _cell_graph_executor(self, phase='save')
-
     def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
         """
         Adds a parameter to the current cell.
@@ -1762,35 +1771,10 @@
         if not isinstance(param, Parameter) and param is not None:
             raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
                             f"but got {type(param)}.")
-        if isinstance(param, Parameter) and param.name
+        if isinstance(param, Parameter) and _is_parameter_generated(param.name):
             param.name = param_name
         self._params[param_name] = param
 
-    def cast_param(self, param):
-        """
-        Cast parameter according to auto mix precision level in pynative mode.
-
-        This interface is currently used in the case of auto mix precision and usually needs not to be used explicitly.
-
-        Args:
-            param (Parameter): Parameters, the type of which should be cast.
-
-        Returns:
-            Parameter, the input parameter with type automatically cast.
-        """
-        msg = f"'cast_param' is deprecated from version 2.0 and will be removed in a future version."
-        logger.warning(msg)
-        mixed_type = self.get_mixed_precision_type()
-        if mixed_type != MixedPrecisionType.NOTSET:
-            if mixed_type == MixedPrecisionType.FP32:
-                param.set_cast_dtype(mstype.float32)
-            elif mixed_type == MixedPrecisionType.FP16:
-                param.set_cast_dtype(mstype.float16)
-        elif hasattr(param, "set_cast_dtype"):
-            # retest dtype
-            param.set_cast_dtype()
-        return param
-
     def insert_child_to_cell(self, child_name, child_cell):
         """
         Adds a child cell to the current cell with a given name.
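Both `_predict` and `__call__` now choose the jit configuration the same way: the cell's own `jit_config_dict` is used when one was set, otherwise a default-constructed `JitConfig()` is handed to the graph executor. A hedged sketch of that fallback, with a stand-in `DefaultConfig` class instead of the real `mindspore.JitConfig`:

class DefaultConfig:
    """Stand-in for the framework's default jit config object."""
    def __init__(self):
        self.jit_config_dict = {}          # framework defaults would live here

def pick_jit_config(cell_jit_config_dict):
    # Prefer the config the user attached to the cell; otherwise fall back
    # to a freshly constructed default, mirroring the branch in the diff.
    if cell_jit_config_dict:
        return cell_jit_config_dict
    return DefaultConfig().jit_config_dict

print(pick_jit_config({"jit_level": "O1"}))   # user-provided config wins
print(pick_jit_config({}))                    # falls back to the defaults -> {}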
@@ -1850,27 +1834,10 @@ class Cell(Cell_):
|
|
|
1850
1834
|
"""
|
|
1851
1835
|
Remove the redundant parameters.
|
|
1852
1836
|
|
|
1853
|
-
|
|
1837
|
+
.. warning::
|
|
1838
|
+
This interface will be deprecated in future versions.
|
|
1854
1839
|
"""
|
|
1855
|
-
|
|
1856
|
-
for _, cell in cells:
|
|
1857
|
-
params = cell._params.items()
|
|
1858
|
-
for param_name, param in list(params):
|
|
1859
|
-
if param.name not in self.parallel_parameter_name_list:
|
|
1860
|
-
cell._params.pop(param_name)
|
|
1861
|
-
logger.info("remove the redundant parameter: %s", param.name)
|
|
1862
|
-
continue
|
|
1863
|
-
cell_dict = cell.__dict__
|
|
1864
|
-
for key in cell_dict:
|
|
1865
|
-
if isinstance(cell_dict[key], ParameterTuple):
|
|
1866
|
-
param_tuple = cell_dict[key]
|
|
1867
|
-
new_param_tuple = []
|
|
1868
|
-
for param in param_tuple:
|
|
1869
|
-
if param.name not in self.parallel_parameter_name_list:
|
|
1870
|
-
logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
|
|
1871
|
-
continue
|
|
1872
|
-
new_param_tuple.append(param)
|
|
1873
|
-
cell.__dict__[key] = ParameterTuple(new_param_tuple)
|
|
1840
|
+
logger.warning(f"'remove_redundant_parameters' will be deprecated in future versions.")
|
|
1874
1841
|
|
|
1875
1842
|
def _get_cell_parallel_mode(self):
|
|
1876
1843
|
"""Determine whether the current cell is in parallel mode."""
|
|
@@ -1926,16 +1893,13 @@ class Cell(Cell_):
|
|
|
1926
1893
|
# replace all original usage.
|
|
1927
1894
|
cells = self.cells_and_names()
|
|
1928
1895
|
is_parallel_mode = self._get_cell_parallel_mode()
|
|
1929
|
-
is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
|
|
1930
1896
|
|
|
1931
1897
|
for _, cell in cells:
|
|
1932
1898
|
params = cell._params.items()
|
|
1933
1899
|
for param_name, param in params:
|
|
1934
|
-
not_sliced = not param.sliced
|
|
1935
|
-
judgment = not_sliced
|
|
1936
1900
|
if param.param_info.is_pipeline_shared_param:
|
|
1937
1901
|
continue
|
|
1938
|
-
if
|
|
1902
|
+
if is_parallel_mode and not param.sliced:
|
|
1939
1903
|
continue
|
|
1940
1904
|
if not auto_parallel_mode:
|
|
1941
1905
|
cell._params[param_name] = _updata(param)
|
|
@@ -1948,11 +1912,9 @@ class Cell(Cell_):
|
|
|
1948
1912
|
param_tuple = cell_dict[key]
|
|
1949
1913
|
new_param_tuple = []
|
|
1950
1914
|
for param in param_tuple:
|
|
1951
|
-
not_sliced = not param.sliced
|
|
1952
|
-
judgment = not_sliced
|
|
1953
1915
|
if param.param_info.is_pipeline_shared_param:
|
|
1954
1916
|
continue
|
|
1955
|
-
if
|
|
1917
|
+
if is_parallel_mode and not param.sliced:
|
|
1956
1918
|
continue
|
|
1957
1919
|
if not auto_parallel_mode:
|
|
1958
1920
|
new_param_tuple.append(_updata(param))
|
|
@@ -2591,15 +2553,6 @@ class Cell(Cell_):
|
|
|
2591
2553
|
self.add_flags_recursive(broadcast_flag=mode)
|
|
2592
2554
|
return self
|
|
2593
2555
|
|
|
2594
|
-
def set_auto_parallel(self):
|
|
2595
|
-
"""
|
|
2596
|
-
Set the cell to auto parallel mode.
|
|
2597
|
-
|
|
2598
|
-
Note:
|
|
2599
|
-
This interface is deprecated.
|
|
2600
|
-
"""
|
|
2601
|
-
logger.warning("'set_auto_parallel' function is deprecated.")
|
|
2602
|
-
|
|
2603
2556
|
def set_jit_config(self, jit_config):
|
|
2604
2557
|
"""
|
|
2605
2558
|
Set jit config for cell.
|
|
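A hedged usage sketch of `set_jit_config`; the chosen `jit_level` value is illustrative:

```python
import mindspore as ms
from mindspore import nn

net = nn.Dense(3, 4)
# Attach a JitConfig to this cell so the chosen options apply when it is compiled.
net.set_jit_config(ms.JitConfig(jit_level="O1"))
```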
@@ -2645,25 +2598,38 @@ class Cell(Cell_):
|
|
|
2645
2598
|
raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
|
|
2646
2599
|
Tensor._flatten_tensors(self.trainable_params(), fusion_size) # pylint: disable=W0212
|
|
2647
2600
|
|
|
2648
|
-
def register_forward_pre_hook(self, hook_fn):
|
|
2601
|
+
def register_forward_pre_hook(self, hook_fn, with_kwargs=False):
|
|
2649
2602
|
"""
|
|
2650
2603
|
Register forward pre hook function for Cell object.
|
|
2651
2604
|
|
|
2605
|
+
The hook will be called before :func:`mindspore.nn.Cell.construct` is invoked.
|
|
2606
|
+
|
|
2607
|
+
The hook function should be one of the following signatures:
|
|
2608
|
+
|
|
2609
|
+
- `hook_fn(cell, args) -> None or new_args` , when `with_kwargs` is ``False`` .
|
|
2610
|
+
- `hook_fn(cell, args, kwargs) -> None or (new_args, new_kwargs)` , when `with_kwargs` is ``True`` .
|
|
2611
|
+
|
|
2612
|
+
where:
|
|
2613
|
+
|
|
2614
|
+
- `cell` (Cell): Cell object on which the hook is registered.
|
|
2615
|
+
- `args` (tuple): Positional arguments passed to the `construct` function.
|
|
2616
|
+
- `kwargs` (dict): Keyword arguments passed to the `construct` function. Only passed to `hook_fn` when
|
|
2617
|
+
`with_kwargs` is ``True`` .
|
|
2618
|
+
|
|
2652
2619
|
Note:
|
|
2653
|
-
- The
|
|
2654
|
-
-
|
|
2655
|
-
|
|
2656
|
-
|
|
2657
|
-
|
|
2658
|
-
|
|
2659
|
-
|
|
2660
|
-
|
|
2661
|
-
`construct` function of Cell object. In the pynative mode, if the `register_forward_pre_hook` function is
|
|
2662
|
-
called in the `construct` function of the Cell object, a hook function will be added at each run time of
|
|
2663
|
-
Cell object.
|
|
2620
|
+
- The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
|
|
2621
|
+
- The `hook_fn` can modify the forward inputs by returning new inputs. If `with_kwargs` is ``False`` , a
|
|
2622
|
+
single value (which will be wrapped into a tuple unless already a tuple) or a tuple of args should be
|
|
2623
|
+
returned. If `with_kwargs` is ``True`` , both `args` and `kwargs` should be returned.
|
|
2624
|
+
- To prevent failures when switching to graph mode, it is not recommended to call this method in the
|
|
2625
|
+
`construct` function of Cell object.
|
|
2626
|
+
- In the pynative mode, if this method is called inside the `construct` function of the Cell object, a
|
|
2627
|
+
`hook_fn` will be added each time the Cell object runs.
|
|
2664
2628
|
|
|
2665
2629
|
Args:
|
|
2666
2630
|
hook_fn (function): Python function. Forward pre hook function.
|
|
2631
|
+
with_kwargs (bool, optional): Specifies whether `hook_fn` will be passed the kwargs given to the `construct`
|
|
2632
|
+
function. Default: ``False`` .
|
|
2667
2633
|
|
|
2668
2634
|
Returns:
|
|
2669
2635
|
A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
|
|
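The new `with_kwargs` behaviour documented above can be exercised as in the following minimal sketch (cell type, shapes, and the hook body are illustrative):

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)

def pre_hook(cell, args, kwargs):
    # With with_kwargs=True the hook must return None or (new_args, new_kwargs).
    new_args = (args[0] + 1,)
    return new_args, kwargs

net = nn.Dense(2, 2)
handle = net.register_forward_pre_hook(pre_hook, with_kwargs=True)
out = net(Tensor(np.ones((1, 2), np.float32)))
handle.remove()  # detach the hook via the returned handle
```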
@@ -2705,60 +2671,66 @@ class Cell(Cell_):
|
|
|
2705
2671
|
if context._get_mode() == context.GRAPH_MODE:
|
|
2706
2672
|
return HookHandle()
|
|
2707
2673
|
check_hook_fn(hook_fn)
|
|
2708
|
-
handle = HookHandle(self._forward_pre_hook)
|
|
2674
|
+
handle = HookHandle(self._forward_pre_hook, extra_dict=self._forward_pre_hook_with_kwargs)
|
|
2709
2675
|
self._forward_pre_hook[handle.handle_id] = hook_fn
|
|
2676
|
+
if with_kwargs:
|
|
2677
|
+
self._forward_pre_hook_with_kwargs[handle.handle_id] = True
|
|
2710
2678
|
return handle
|
|
2711
2679
|
|
|
2712
|
-
def _run_forward_pre_hook(self,
|
|
2680
|
+
def _run_forward_pre_hook(self, args, kwargs):
|
|
2713
2681
|
"""
|
|
2714
2682
|
Running forward pre hook function registered on Cell object.
|
|
2683
|
+
"""
|
|
2684
|
+
for hook_id, hook_fn in self._forward_pre_hook.items():
|
|
2685
|
+
if hook_id in self._forward_pre_hook_with_kwargs:
|
|
2686
|
+
ret = hook_fn(self, args, kwargs)
|
|
2687
|
+
if ret is not None:
|
|
2688
|
+
if isinstance(ret, tuple) and len(ret) == 2:
|
|
2689
|
+
args, kwargs = ret
|
|
2690
|
+
else:
|
|
2691
|
+
raise RuntimeError(
|
|
2692
|
+
"forward pre hook with kwargs must return None or a tuple of (new_args, new_kwargs), "
|
|
2693
|
+
f"but got {ret}"
|
|
2694
|
+
)
|
|
2695
|
+
else:
|
|
2696
|
+
ret = hook_fn(self, args)
|
|
2697
|
+
if ret is not None:
|
|
2698
|
+
if not isinstance(ret, tuple):
|
|
2699
|
+
ret = (ret,)
|
|
2700
|
+
args = ret
|
|
2701
|
+
return args, kwargs
|
|
2715
2702
|
|
|
2716
|
-
|
|
2717
|
-
|
|
2703
|
+
def register_forward_hook(self, hook_fn, with_kwargs=False):
|
|
2704
|
+
"""
|
|
2705
|
+
Register forward hook function for Cell object.
|
|
2718
2706
|
|
|
2719
|
-
|
|
2720
|
-
- **outputs** - New input objects or none.
|
|
2707
|
+
This hook will be called after :func:`mindspore.nn.Cell.construct` has computed an output.
|
|
2721
2708
|
|
|
2722
|
-
|
|
2723
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
2724
|
-
"""
|
|
2725
|
-
forward_pre_hook_inputs = inputs
|
|
2726
|
-
for fn in self._forward_pre_hook.values():
|
|
2727
|
-
ret = fn(self, forward_pre_hook_inputs)
|
|
2728
|
-
if ret is not None:
|
|
2729
|
-
if not isinstance(ret, tuple):
|
|
2730
|
-
forward_pre_hook_inputs = (ret,)
|
|
2731
|
-
else:
|
|
2732
|
-
forward_pre_hook_inputs = ret
|
|
2709
|
+
The hook function should be one of the following signatures:
|
|
2733
2710
|
|
|
2734
|
-
|
|
2735
|
-
|
|
2736
|
-
forward_pre_hook_inputs = (forward_pre_hook_inputs,)
|
|
2737
|
-
if len(forward_pre_hook_inputs) != len(inputs):
|
|
2738
|
-
raise TypeError(
|
|
2739
|
-
"The forward pre hook return value size is {} not equal to input size {}".format(
|
|
2740
|
-
len(forward_pre_hook_inputs), len(inputs)))
|
|
2741
|
-
return forward_pre_hook_inputs
|
|
2711
|
+
- `hook_fn(cell, args, output) -> None or new_output` , when `with_kwargs` is ``False`` .
|
|
2712
|
+
- `hook_fn(cell, args, kwargs, output) -> None or new_output` , when `with_kwargs` is ``True`` .
|
|
2742
2713
|
|
|
2743
|
-
|
|
2744
|
-
|
|
2745
|
-
|
|
2714
|
+
where:
|
|
2715
|
+
|
|
2716
|
+
- `cell` (Cell): Cell object on which the hook is registered.
|
|
2717
|
+
- `args` (tuple): Positional arguments passed to the `construct` function.
|
|
2718
|
+
- `kwargs` (dict): Keyword arguments passed to the `construct` function. Only passed to `hook_fn` when
|
|
2719
|
+
`with_kwargs` is ``True`` .
|
|
2720
|
+
- `output`: Output generated by the `construct` function.
|
|
2746
2721
|
|
|
2747
2722
|
Note:
|
|
2748
|
-
- The
|
|
2749
|
-
-
|
|
2750
|
-
|
|
2751
|
-
|
|
2752
|
-
|
|
2753
|
-
|
|
2754
|
-
hook_fn(cell, inputs, output) -> new output object or none.
|
|
2755
|
-
- In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
|
|
2756
|
-
`construct` function of Cell object. In the pynative mode, if the `register_forward_hook` function is
|
|
2757
|
-
called in the `construct` function of the Cell object, a hook function will be added at each run time of
|
|
2758
|
-
Cell object.
|
|
2723
|
+
- The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
|
|
2724
|
+
- The `hook_fn` can modify the forward outputs by returning new outputs.
|
|
2725
|
+
- To prevent failures when switching to graph mode, it is not recommended to call this method in the
|
|
2726
|
+
`construct` function of Cell object.
|
|
2727
|
+
- In the pynative mode, if this method is called inside the `construct` function of the Cell object, a
|
|
2728
|
+
`hook_fn` will be added each time the Cell object runs.
|
|
2759
2729
|
|
|
2760
2730
|
Args:
|
|
2761
2731
|
hook_fn (function): Python function. Forward hook function.
|
|
2732
|
+
with_kwargs (bool, optional): Specifies whether `hook_fn` will be passed the kwargs given to the `construct`
|
|
2733
|
+
function. Default: ``False`` .
|
|
2762
2734
|
|
|
2763
2735
|
Returns:
|
|
2764
2736
|
A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
|
|
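Similarly, a minimal sketch of the kwargs-aware forward hook described above (names and shapes are illustrative):

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)

def forward_hook(cell, args, kwargs, output):
    # Returning a value replaces the forward output; returning None keeps it.
    return output * 2

net = nn.Dense(2, 2)
handle = net.register_forward_hook(forward_hook, with_kwargs=True)
out = net(Tensor(np.ones((1, 2), np.float32)))
handle.remove()
```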
@@ -2804,38 +2776,24 @@ class Cell(Cell_):
|
|
|
2804
2776
|
if context._get_mode() == context.GRAPH_MODE:
|
|
2805
2777
|
return HookHandle()
|
|
2806
2778
|
check_hook_fn(hook_fn)
|
|
2807
|
-
handle = HookHandle(self._forward_hook)
|
|
2779
|
+
handle = HookHandle(self._forward_hook, extra_dict=self._forward_hook_with_kwargs)
|
|
2808
2780
|
self._forward_hook[handle.handle_id] = hook_fn
|
|
2781
|
+
if with_kwargs:
|
|
2782
|
+
self._forward_hook_with_kwargs[handle.handle_id] = True
|
|
2809
2783
|
return handle
|
|
2810
2784
|
|
|
2811
|
-
def _run_forward_hook(self,
|
|
2785
|
+
def _run_forward_hook(self, args, kwargs, output):
|
|
2812
2786
|
"""
|
|
2813
2787
|
Running forward hook function registered on Cell object.
|
|
2814
|
-
|
|
2815
|
-
Args:
|
|
2816
|
-
inputs: The input objects of Cell object.
|
|
2817
|
-
output: The output object of Cell object.
|
|
2818
|
-
|
|
2819
|
-
Returns:
|
|
2820
|
-
- **output** - New output object or none.
|
|
2821
|
-
|
|
2822
|
-
Supported Platforms:
|
|
2823
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
2824
2788
|
"""
|
|
2825
|
-
|
|
2826
|
-
|
|
2827
|
-
|
|
2789
|
+
for hook_id, hook_fn in self._forward_hook.items():
|
|
2790
|
+
if hook_id in self._forward_hook_with_kwargs:
|
|
2791
|
+
ret = hook_fn(self, args, kwargs, output)
|
|
2792
|
+
else:
|
|
2793
|
+
ret = hook_fn(self, args, output)
|
|
2828
2794
|
if ret is not None:
|
|
2829
|
-
|
|
2830
|
-
|
|
2831
|
-
if isinstance(output, tuple):
|
|
2832
|
-
if not isinstance(forward_hook_output, tuple):
|
|
2833
|
-
forward_hook_output = (forward_hook_output,)
|
|
2834
|
-
if len(forward_hook_output) != len(output):
|
|
2835
|
-
raise TypeError(
|
|
2836
|
-
"The forward hook return value size is {} not equal to output size {}".format(
|
|
2837
|
-
len(forward_hook_output), len(output)))
|
|
2838
|
-
return forward_hook_output
|
|
2795
|
+
output = ret
|
|
2796
|
+
return output
|
|
2839
2797
|
|
|
2840
2798
|
def register_backward_pre_hook(self, hook_fn):
|
|
2841
2799
|
"""
|
|
@@ -3116,7 +3074,6 @@ class Cell(Cell_):
|
|
|
3116
3074
|
OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
|
|
3117
3075
|
('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
|
|
3118
3076
|
"""
|
|
3119
|
-
# TODO: Remove `args` and the parsing logic when BC allows.
|
|
3120
3077
|
if args:
|
|
3121
3078
|
# DeprecationWarning is ignored by default
|
|
3122
3079
|
warnings.warn(
|
|
@@ -3169,7 +3126,7 @@ class Cell(Cell_):
|
|
|
3169
3126
|
|
|
3170
3127
|
It should have the following signature:
|
|
3171
3128
|
|
|
3172
|
-
hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None
|
|
3129
|
+
hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None
|
|
3173
3130
|
|
|
3174
3131
|
Args:
|
|
3175
3132
|
hook (Callable): The hook function before `load_state_dict` is called.
|
|
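A hedged sketch of a hook matching the signature above; the key-renaming logic is illustrative, and the registration call itself belongs to the method this hunk documents (outside the excerpt):

```python
def rename_legacy_keys(cell, state_dict, prefix, local_metadata, strict,
                       missing_keys, unexpected_keys, error_msgs):
    # Migrate an old parameter name in place before load_state_dict consumes it.
    old_key, new_key = prefix + "gamma", prefix + "weight"
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)
```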
@@ -3584,12 +3541,6 @@ class Cell(Cell_):
|
|
|
3584
3541
|
for param in params:
|
|
3585
3542
|
param.set_param_ps(init_in_server)
|
|
3586
3543
|
|
|
3587
|
-
@deprecated("1.8", "set_param_fl")
|
|
3588
|
-
def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
|
|
3589
|
-
params = self.parameters_and_names()
|
|
3590
|
-
for param in params:
|
|
3591
|
-
param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
|
|
3592
|
-
|
|
3593
3544
|
def set_comm_fusion(self, fusion_type, recurse=True):
|
|
3594
3545
|
"""
|
|
3595
3546
|
Set `comm_fusion` for all the parameters in this cell. Please refer to the description of
|
|
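A brief, hedged sketch of `set_comm_fusion`; the fusion type value is illustrative:

```python
from mindspore import nn

net = nn.Dense(8, 8)
# Tag every parameter in this cell (and, by default, its sub-cells) with
# the given comm_fusion value used to group gradient communication.
net.set_comm_fusion(fusion_type=1)
```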
@@ -3698,7 +3649,7 @@ class Cell(Cell_):
|
|
|
3698
3649
|
self._recompute()
|
|
3699
3650
|
if 'mp_comm_recompute' in kwargs.keys():
|
|
3700
3651
|
self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
|
|
3701
|
-
if 'parallel_optimizer_comm_recompute' in kwargs
|
|
3652
|
+
if 'parallel_optimizer_comm_recompute' in kwargs:
|
|
3702
3653
|
if kwargs.get('parallel_optimizer_comm_recompute', False):
|
|
3703
3654
|
logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
|
|
3704
3655
|
"is replaced with zero3.")
|
|
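For the kwargs handled in this hunk, a hedged call sketch (flag values are illustrative; per the warning above, enabling `parallel_optimizer_comm_recompute` currently falls back to zero3):

```python
from mindspore import nn

block = nn.Dense(16, 16)
# Recompute this block's activations in the backward pass, and also recompute
# the communication introduced by model-parallel operators.
block.recompute(mp_comm_recompute=True, parallel_optimizer_comm_recompute=False)
```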
@@ -3711,38 +3662,6 @@ class Cell(Cell_):
|
|
|
3711
3662
|
"the key kwargs must be 'mp_comm_recompute', "
|
|
3712
3663
|
"'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)
|
|
3713
3664
|
|
|
3714
|
-
@deprecated("2.3", "infer_param_pipeline_stage")
|
|
3715
|
-
def infer_param_pipeline_stage(self):
|
|
3716
|
-
"""
|
|
3717
|
-
Infer pipeline stages of all parameters in the cell.
|
|
3718
|
-
|
|
3719
|
-
Note:
|
|
3720
|
-
- The interface is deprecated from version 2.3 and will be removed in a future version.
|
|
3721
|
-
|
|
3722
|
-
Returns:
|
|
3723
|
-
The params belong to current stage in pipeline parallel.
|
|
3724
|
-
|
|
3725
|
-
Raises:
|
|
3726
|
-
RuntimeError: If there is a parameter does not belong to any stage.
|
|
3727
|
-
"""
|
|
3728
|
-
from mindspore.parallel._utils import _get_global_rank, _get_device_num
|
|
3729
|
-
logger.warning(f"This interface may be deleted in the future.")
|
|
3730
|
-
stage_num = context.get_auto_parallel_context("pipeline_stages")
|
|
3731
|
-
device_num = _get_device_num()
|
|
3732
|
-
rank_id = _get_global_rank()
|
|
3733
|
-
per_stage_devices = device_num // stage_num
|
|
3734
|
-
current_stage = rank_id // per_stage_devices
|
|
3735
|
-
params = []
|
|
3736
|
-
for param in self.trainable_params():
|
|
3737
|
-
if not param._pipeline_stage_list: # pylint: disable=W0212
|
|
3738
|
-
raise RuntimeError("For 'infer_param_pipeline_stage', the parameter {} does not belong to any stage, "
|
|
3739
|
-
"please check whether the cell where the param locates has been set "
|
|
3740
|
-
"'pipeline_stage'. Otherwise, the parameter should use 'add_pipeline_stage' "
|
|
3741
|
-
"to add its stage information".format(param.name))
|
|
3742
|
-
if current_stage in param._pipeline_stage_list:
|
|
3743
|
-
params.append(param)
|
|
3744
|
-
return params
|
|
3745
|
-
|
|
3746
3665
|
def place(self, role, rank_id):
|
|
3747
3666
|
"""
|
|
3748
3667
|
Set the label for all operators in this cell.
|
|
@@ -3772,19 +3691,6 @@ class Cell(Cell_):
|
|
|
3772
3691
|
for op in all_ops:
|
|
3773
3692
|
op.place(role, rank_id)
|
|
3774
3693
|
|
|
3775
|
-
def _mixed_precision_cast(self, inputs):
|
|
3776
|
-
mixed_type = self.get_mixed_precision_type()
|
|
3777
|
-
if mixed_type == MixedPrecisionType.NOTSET:
|
|
3778
|
-
return inputs
|
|
3779
|
-
if mixed_type == MixedPrecisionType.FP16:
|
|
3780
|
-
cast_type = mstype.float16
|
|
3781
|
-
elif mixed_type == MixedPrecisionType.BF16:
|
|
3782
|
-
cast_type = mstype.bfloat16
|
|
3783
|
-
else:
|
|
3784
|
-
cast_type = mstype.float32
|
|
3785
|
-
cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
|
|
3786
|
-
return cast_inputs
|
|
3787
|
-
|
|
3788
3694
|
def _get_attr_from_cell(self, network):
|
|
3789
3695
|
if not isinstance(network, Cell):
|
|
3790
3696
|
return
|
|
@@ -3793,91 +3699,11 @@ class Cell(Cell_):
|
|
|
3793
3699
|
if hasattr(network, "_amp_level"):
|
|
3794
3700
|
self._amp_level = getattr(network, "_amp_level")
|
|
3795
3701
|
|
|
3796
|
-
def
|
|
3702
|
+
def _set_jit_graph_name(self, key):
|
|
3797
3703
|
"""
|
|
3798
|
-
|
|
3799
|
-
|
|
3800
|
-
|
|
3801
|
-
This is an experimental prototype that is subject to change and/or deletion.
|
|
3802
|
-
|
|
3803
|
-
Note:
|
|
3804
|
-
- The `_register_parameters_hook(forward_hook, backward_hook)` only work in graph mode
|
|
3805
|
-
- The `forward_hook` must be defined as the following code.
|
|
3806
|
-
`parameters`: the tuple of the trainble parameters of the Cell, each element in the tuple shuould be
|
|
3807
|
-
in the format of `(param_name, Parameter)`.
|
|
3808
|
-
- The `forward_hook` should have the following signature:
|
|
3809
|
-
forward_hook(parameters) -> None.
|
|
3810
|
-
- The `backward_hook` must be defined as the following code.
|
|
3811
|
-
`gradients`: the tuple of the gradients corresponding to the trainble parameters of the Cell, each
|
|
3812
|
-
element in the tuple shuould be in the format of `(param_name, gradient)`.
|
|
3813
|
-
- The `backward_hook` should have the following signature:
|
|
3814
|
-
backward_hook(parameters) -> New gradients.
|
|
3815
|
-
|
|
3816
|
-
Args:
|
|
3817
|
-
forward_hook (function, optional): Python function or ``None``, Forward hook function. Default: ``None``
|
|
3818
|
-
backward_hook (function, optional): Python function or ``None``, Backward hook function. Default ``None``
|
|
3819
|
-
all (bool, optional): bool, whether to set hooks for all sub cells recursively. Default: ``False``
|
|
3820
|
-
|
|
3821
|
-
Returns:
|
|
3822
|
-
None
|
|
3823
|
-
|
|
3824
|
-
Raises:
|
|
3825
|
-
RuntimeError: If the `forward_hook` or `backward_hook ` has unspoorted syntax under GRAPH MODE.
|
|
3826
|
-
TypeError: If the `forward_hook` or `backward_hook` is not defined as required.
|
|
3827
|
-
|
|
3828
|
-
Supported Platforms:
|
|
3829
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
3830
|
-
|
|
3831
|
-
Examples:
|
|
3832
|
-
>>> import mindspore as ms
|
|
3833
|
-
>>> from mindspore import Tensor, nn, ops, Parameter
|
|
3834
|
-
>>>
|
|
3835
|
-
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
3836
|
-
>>> def parameter_hook(parameters):
|
|
3837
|
-
... print("--- enter parameter hook ---")
|
|
3838
|
-
... for name, param in parameters:
|
|
3839
|
-
... print (name, param)
|
|
3840
|
-
... print("--- leave parameter hook ---")
|
|
3841
|
-
...
|
|
3842
|
-
>>> def gradient_hook(gradients):
|
|
3843
|
-
... print("--- enter gradient hook ---")
|
|
3844
|
-
... outs = []
|
|
3845
|
-
... for name, gradient in gradients:
|
|
3846
|
-
... print(name, gradient)
|
|
3847
|
-
... outs.append(gradient * 2) # double gradient
|
|
3848
|
-
... print("--- leave gradient hook ---")
|
|
3849
|
-
... return outs
|
|
3850
|
-
...
|
|
3851
|
-
>>> class Net(nn.Cell):
|
|
3852
|
-
... def __init__(self)
|
|
3853
|
-
... super(Net, self).__init__()
|
|
3854
|
-
... self.w = Parameter(Tensor(np.array([3.0], np.float32)), name='w')
|
|
3855
|
-
... def construct(self, x):
|
|
3856
|
-
... return self.w * x
|
|
3857
|
-
...
|
|
3858
|
-
>>> grad = ops.GradOperation(get_by_list=True)
|
|
3859
|
-
>>> net = Net()
|
|
3860
|
-
>>> net._register_parameters_hook(forward_hook=parameter_hook, backward_hook=gradient_hook)
|
|
3861
|
-
>>> x = Tensor(np.array([4.0]).astype(np.float32))
|
|
3862
|
-
>>> output = grad(net, net.trainable_params())(x)
|
|
3863
|
-
--- enter parameter hook ---
|
|
3864
|
-
w
|
|
3865
|
-
Tensor(shape=[1], dtype=Float32, value=[ 3.00000000e+00])
|
|
3866
|
-
--- leave parameter hook ---
|
|
3867
|
-
--- enter gradient hook ---
|
|
3868
|
-
w
|
|
3869
|
-
Tensor(shape=[1], dtype=Float32, value=[ 4.00000000e+00])
|
|
3870
|
-
--- leave gradient hook ---
|
|
3871
|
-
>>> print("doubled grad: ", output)
|
|
3872
|
-
doubled grad: (Tensor(shape=[1], dtype=Float32, value=[ 8.00000000e+00]),)
|
|
3873
|
-
"""
|
|
3874
|
-
if not all:
|
|
3875
|
-
self._parameters_forward_hook = forward_hook
|
|
3876
|
-
self._parameters_backward_hook = backward_hook
|
|
3877
|
-
else:
|
|
3878
|
-
for _, cell in self.cells_and_names():
|
|
3879
|
-
cell._parameters_forward_hook = forward_hook
|
|
3880
|
-
cell._parameters_backward_hook = backward_hook
|
|
3704
|
+
Set jit graph name.
|
|
3705
|
+
"""
|
|
3706
|
+
self._jit_graph_name = key
|
|
3881
3707
|
|
|
3882
3708
|
|
|
3883
3709
|
class GraphCell(Cell):
|