mindspore 2.4.10__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +562 -393
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,13 +15,26 @@
|
|
|
15
15
|
"""cell"""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
|
-
import gc
|
|
19
18
|
import inspect
|
|
20
19
|
import os
|
|
21
20
|
import time
|
|
22
|
-
|
|
23
|
-
import
|
|
24
|
-
|
|
21
|
+
import warnings
|
|
22
|
+
import itertools
|
|
23
|
+
from collections import OrderedDict, namedtuple
|
|
24
|
+
from typing import (
|
|
25
|
+
Dict,
|
|
26
|
+
Optional,
|
|
27
|
+
Set,
|
|
28
|
+
Callable,
|
|
29
|
+
List,
|
|
30
|
+
Tuple,
|
|
31
|
+
Iterator,
|
|
32
|
+
Any,
|
|
33
|
+
TypeVar,
|
|
34
|
+
Mapping
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
import mindspore as ms
|
|
25
38
|
from mindspore._checkparam import args_type_check, check_hook_fn
|
|
26
39
|
from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
|
|
27
40
|
from mindspore import log as logger
|
|
@@ -34,19 +47,62 @@ from mindspore import _checkparam as Validator
|
|
|
34
47
|
from mindspore.common import dtype as mstype
|
|
35
48
|
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
|
|
36
49
|
_no_grad
|
|
37
|
-
from mindspore.common.api import
|
|
50
|
+
from mindspore.common.api import _convert_python_data, _get_args_for_run_predict
|
|
38
51
|
from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
|
|
39
|
-
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
52
|
+
from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple
|
|
40
53
|
from mindspore.common.tensor import Tensor
|
|
41
54
|
from mindspore.ops.operations import Cast
|
|
42
55
|
from mindspore.ops.primitive import Primitive
|
|
43
56
|
from mindspore.ops.operations import _inner_ops as inner
|
|
44
57
|
from mindspore.parallel.shard import Shard
|
|
58
|
+
from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
|
|
45
59
|
from mindspore._check_jit_forbidden_api import jit_forbidden_register
|
|
46
60
|
from mindspore.common._decorator import deprecated
|
|
47
61
|
from mindspore.common._register_for_recompute import recompute_registry
|
|
48
62
|
|
|
49
63
|
|
|
64
|
+
# Names exported by ``from mindspore.nn.cell import *``.
__all__ = [
    "register_cell_buffer_registration_hook",
]

# Process-wide registry of buffer-registration hooks, keyed by the id of the
# removable handle returned from ``register_cell_buffer_registration_hook``.
# Every ``Cell.register_buffer`` call iterates over the values of this mapping.
_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
# NOTE(review): presumably the key suffix under which a cell's extra state is
# stored in a state dict; its use is not visible in this part of the file.
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
    """Pair of key collections: ``missing_keys`` and ``unexpected_keys``.

    When both fields are empty the textual representation is a short success
    message; otherwise the regular namedtuple representation is used.
    """

    def __repr__(self):
        # Any non-empty field means there is something to report in full.
        if self.missing_keys or self.unexpected_keys:
            return super().__repr__()
        return "<All keys matched successfully>"

    # ``str()`` and ``repr()`` intentionally render the same text.
    __str__ = __repr__
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def register_cell_buffer_registration_hook(hook: Callable[..., None],):
    r"""Register a buffer registration hook shared by all cells.

    .. warning ::

        This adds global state to the `nn.Cell` cell

    The hook is invoked on every :func:`register_buffer` call and should have
    the following signature::

        hook(cell, name, buffer) -> None or new buffer

    The hook may modify its input, or return a single modified value that is
    then used in place of the buffer being registered.

    Args:
        hook (Callable): The callable stored in the global hook registry.

    Returns:
        A handle that can be used to remove the added hook by calling
        `handle.remove()`.
    """
    # Imported at call time, exactly as the original implementation does.
    from mindspore.utils.hooks import _RemovableHandle

    removable = _RemovableHandle(_global_buffer_registration_hooks)
    # The handle's id is the key under which the hook is stored.
    _global_buffer_registration_hooks[removable.id] = hook
    return removable
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
|
|
50
106
|
class Cell(Cell_):
|
|
51
107
|
"""
|
|
52
108
|
The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
|
|
@@ -60,7 +116,7 @@ class Cell(Cell_):
|
|
|
60
116
|
.. note::
|
|
61
117
|
Cell is the inference mode by default. For a class that inherits a Cell,
|
|
62
118
|
if the training and inference have different structures, the subclass performs the inference branch by default.
|
|
63
|
-
To set the training mode, refer to
|
|
119
|
+
To set the training mode, refer to :func:`mindspore.nn.Cell.set_train` .
|
|
64
120
|
|
|
65
121
|
.. warning::
|
|
66
122
|
In the subclass of Cell, it's not allowed to define a method named 'cast' and not allowed to define an attribute
|
|
@@ -105,8 +161,11 @@ class Cell(Cell_):
|
|
|
105
161
|
'_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
|
|
106
162
|
'_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
|
|
107
163
|
'_cell_backward_pre_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
|
|
108
|
-
'_attr_synced', 'pynative', 'requires_grad', 'cell_type'
|
|
164
|
+
'_attr_synced', 'pynative', 'requires_grad', 'cell_type',
|
|
165
|
+
'_parameters_forward_hook', '_parameters_backward_hook']
|
|
109
166
|
total_instance_count = 0
|
|
167
|
+
_buffers: Dict[str, Optional[Tensor]]
|
|
168
|
+
_non_persistent_buffers_set: Set[str]
|
|
110
169
|
|
|
111
170
|
def __init__(self, auto_prefix=True, flags=None):
|
|
112
171
|
Cell_.__init__(self, self._cell_tag)
|
|
@@ -114,10 +173,17 @@ class Cell(Cell_):
|
|
|
114
173
|
self.instance_count = Cell.total_instance_count
|
|
115
174
|
self._params = OrderedDict()
|
|
116
175
|
self._cells = OrderedDict()
|
|
176
|
+
super().__setattr__("_buffers", {})
|
|
177
|
+
super().__setattr__("_non_persistent_buffers_set", set())
|
|
178
|
+
super().__setattr__("_state_dict_hooks", OrderedDict())
|
|
179
|
+
super().__setattr__("_state_dict_pre_hooks", OrderedDict())
|
|
180
|
+
super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
|
|
181
|
+
super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
|
|
117
182
|
self._params_list = OrderedDict()
|
|
118
183
|
self._primitives = OrderedDict()
|
|
119
184
|
self.training = False
|
|
120
185
|
self.requires_grad = False
|
|
186
|
+
self.is_top_cell = False
|
|
121
187
|
self.pynative = False
|
|
122
188
|
self._attr_synced = False
|
|
123
189
|
self._param_prefix = ''
|
|
@@ -134,8 +200,8 @@ class Cell(Cell_):
|
|
|
134
200
|
cells_compile_cache[id(self)] = self.compile_cache
|
|
135
201
|
self.parameter_broadcast_done = False
|
|
136
202
|
self._id = 1
|
|
137
|
-
self.
|
|
138
|
-
self.
|
|
203
|
+
self._exist_objs = None
|
|
204
|
+
self._exist_names = None
|
|
139
205
|
self._recompute_cell = None
|
|
140
206
|
self.mixed_precision_type = None
|
|
141
207
|
self.sig = inspect.signature(self.construct)
|
|
@@ -143,7 +209,8 @@ class Cell(Cell_):
|
|
|
143
209
|
|
|
144
210
|
# call gc to release GE session resources used by non-used cell objects
|
|
145
211
|
if os.getenv('GC_COLLECT_IN_CELL') == '1':
|
|
146
|
-
|
|
212
|
+
logger.warning("The convenient environment 'GC_COLLECT_IN_CELL' is deprecated from version 2.5 "
|
|
213
|
+
"and will be removed in a future version.")
|
|
147
214
|
|
|
148
215
|
if flags:
|
|
149
216
|
self.add_flags(**flags)
|
|
@@ -158,6 +225,10 @@ class Cell(Cell_):
|
|
|
158
225
|
self._cell_backward_hook = None
|
|
159
226
|
self._is_recursion_hook = False
|
|
160
227
|
|
|
228
|
+
# parameters hook
|
|
229
|
+
self._parameters_forward_hook = None
|
|
230
|
+
self._parameters_backward_hook = None
|
|
231
|
+
|
|
161
232
|
self.cell_type = None
|
|
162
233
|
self.cast = Cast()
|
|
163
234
|
self._has_config_recompute = False
|
|
@@ -202,6 +273,21 @@ class Cell(Cell_):
|
|
|
202
273
|
def cell_init_args(self):
|
|
203
274
|
return self._cell_init_args
|
|
204
275
|
|
|
276
|
+
@property
def exist_names(self):
    """
    Return the set of existing parameter names added through a tuple or list of parameters.

    The set is created on first access and cached in ``_exist_names``.
    """
    names = self._exist_names
    if names is None:
        # First access: create the empty set and remember it on the instance.
        names = set()
        self._exist_names = names
    return names
|
|
284
|
+
|
|
285
|
+
@property
def exist_objs(self):
    """
    Return the set cached in ``_exist_objs``, creating it empty on first access.

    NOTE(review): presumably the object counterpart of ``exist_names`` — confirm against callers.
    """
    objs = self._exist_objs
    if objs is None:
        # First access: create the empty set and remember it on the instance.
        objs = set()
        self._exist_objs = objs
    return objs
|
|
290
|
+
|
|
205
291
|
@property
|
|
206
292
|
def param_prefix(self):
|
|
207
293
|
"""
|
|
@@ -230,11 +316,6 @@ class Cell(Cell_):
|
|
|
230
316
|
def bprop_debug(self):
|
|
231
317
|
"""
|
|
232
318
|
Get whether cell custom bprop debug is enabled.
|
|
233
|
-
|
|
234
|
-
Tutorial Examples:
|
|
235
|
-
- `Custom Neural Network Layers - Custom Cell Reverse
|
|
236
|
-
<https://mindspore.cn/docs/en/master/model_train/custom_program/network_custom.html
|
|
237
|
-
#custom-cell-reverse>`_
|
|
238
319
|
"""
|
|
239
320
|
return self._bprop_debug
|
|
240
321
|
|
|
@@ -351,8 +432,6 @@ class Cell(Cell_):
|
|
|
351
432
|
raise ValueError("For 'Cell', the property 'pipeline_stage' "
|
|
352
433
|
"can not be less than 0, but got {}".format(value))
|
|
353
434
|
self._pipeline_stage = value
|
|
354
|
-
for item in self.trainable_params():
|
|
355
|
-
item.add_pipeline_stage(value)
|
|
356
435
|
|
|
357
436
|
@property
|
|
358
437
|
def pipeline_segment(self):
|
|
@@ -388,6 +467,374 @@ class Cell(Cell_):
|
|
|
388
467
|
def enable_backward_hook(self):
|
|
389
468
|
return self._enable_backward_hook
|
|
390
469
|
|
|
470
|
+
@jit_forbidden_register
def register_buffer(
        self, name: str, tensor: Optional[Tensor], persistent: bool = True
) -> None:
    r"""Add a buffer to the cell.

    This is typically used to register a buffer that should not be
    considered a model parameter. For example, BatchNorm's `running_mean`
    is not a parameter, but is part of the cell's state. Buffers, by
    default, are persistent and will be saved alongside parameters. This
    behavior can be changed by setting `persistent` to ``False`` . The
    only difference between a persistent buffer and a non-persistent buffer
    is that the latter will not be a part of this cell's :attr:`state_dict` .

    Buffers can be accessed as attributes using given names.

    Args:
        name (str): name of the buffer. The buffer can be accessed
            from this cell using the given name.
        tensor (Tensor): Buffer to be registered. If ``None`` ,
            the buffer is not included in the cell's :attr:`state_dict` .
        persistent (bool, optional): Whether the buffer is part of this cell's :attr:`state_dict`. Default ``True``.

    Raises:
        AttributeError: If called before ``Cell.__init__()`` has created the buffer table.
        TypeError: If `name` is not a string, or `tensor` is neither a Tensor nor ``None`` .
        KeyError: If `name` contains ``"."`` , is empty, or is already an attribute of
            the cell that is not a registered buffer.

    Examples:
        >>> import mindspore
        >>>
        >>> class Net(mindspore.nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.register_buffer("buffer0", mindspore.tensor([1, 2, 3]))
        ...
        ...     def construct(self, x):
        ...         return x + self.buffer0
        ...
        >>> net = Net()
        >>> net.register_buffer("buffer0", mindspore.tensor([4, 5, 6]))
        >>> print(net.buffer0)
        [4 5 6]
    """

    if "_buffers" not in self.__dict__:
        raise AttributeError("cannot assign buffer before Cell.__init__() call")
    if not isinstance(name, str):
        raise TypeError(
            f"buffer name should be a string.But got this type: {type(name)}"
        )
    if "." in name:
        raise KeyError('buffer name can\'t contain "."')
    if name == "":
        raise KeyError('buffer name can\'t be empty string ""')
    # Re-registering an existing buffer name is allowed; shadowing any other attribute is not.
    if hasattr(self, name) and name not in self._buffers:
        raise KeyError(f"attribute '{name}' already exists")
    if tensor is not None and not isinstance(tensor, Tensor):
        raise TypeError(
            f"cannot assign '{type(tensor)}' object to buffer '{name}' "
            "(mindspore Tensor or None required)"
        )
    # Each global hook may replace the tensor by returning a non-None value.
    for hook in _global_buffer_registration_hooks.values():
        output = hook(self, name, tensor)
        if output is not None:
            tensor = output
    if tensor is not None:
        # Mark the tensor itself as a buffer.
        tensor._is_buffer = True
    self._buffers[name] = tensor
    # Keep the non-persistent set in sync with the requested persistence.
    if persistent:
        self._non_persistent_buffers_set.discard(name)
    else:
        self._non_persistent_buffers_set.add(name)
|
|
538
|
+
|
|
539
|
+
@jit_forbidden_register
def get_buffer(self, target: str) -> "Tensor":
    """Look up the buffer named by `target`, raising an error if it does not exist.

    The part of `target` before the last ``"."`` is resolved with
    `get_sub_cell` ; see its docstring for how to specify a fully-qualified
    name. The final component must be a registered buffer of that cell.

    Args:
        target (str): The fully-qualified string name of the buffer
            to look for. (See `get_sub_cell` for how to specify a
            fully-qualified string.)

    Returns:
        Tensor

    Raises:
        AttributeError: If the owning cell has no attribute with the final
            name, or that attribute is not a registered buffer.

    Examples:
        >>> import mindspore
        >>>
        >>> class NetC(mindspore.nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.register_buffer("buffer_c", mindspore.tensor([0, 0, 0]))
        ...
        ...     def construct(self, x):
        ...         return x + self.buffer_c
        ...
        >>> class NetB(mindspore.nn.Cell):
        ...     def __init__(self, net_c):
        ...         super().__init__()
        ...         self.net_c = net_c
        ...         self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
        ...
        ...     def construct(self, x):
        ...         return self.net_c(x) + self.buffer_b
        ...
        >>> class NetA(mindspore.nn.Cell):
        ...     def __init__(self, net_b):
        ...         super().__init__()
        ...         self.net_b = net_b
        ...         self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
        ...
        ...     def construct(self, x):
        ...         return self.net_b(x) + self.buffer_a
        ...
        >>> net_c = NetC()
        >>> net_b = NetB(net_c)
        >>> net_a = NetA(net_b)
        >>> buffer_c = net_a.get_buffer("net_b.net_c.buffer_c")
        >>> print(f'buffer_c is {buffer_c}')
        buffer_c is [0 0 0]

    """
    # Split "a.b.name" into the owning cell path "a.b" and the leaf "name".
    owner_path, _, leaf_name = target.rpartition(".")
    owner = self.get_sub_cell(owner_path)

    if not hasattr(owner, leaf_name):
        raise AttributeError(
            owner._get_name() + " has no attribute `" + leaf_name + "`"
        )

    found = getattr(owner, leaf_name)
    # An attribute with that name exists; accept it only if it is a registered buffer.
    if leaf_name in owner._buffers:
        return found
    raise AttributeError("`" + leaf_name + "` is not a buffer")
|
|
611
|
+
|
|
612
|
+
@jit_forbidden_register
def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Tensor]]:
    r"""Iterate over the cell's buffers, yielding each buffer's name together with the buffer.

    Args:
        prefix (str, optional): prefix to prepend to all buffer names. Default ``""``.
        recurse (bool, optional): if ``True`` , then yields buffers of this cell
            and all sub cells. Otherwise, yields only buffers that
            are direct members of this cell. Default ``True``.
        remove_duplicate (bool, optional): Whether to remove the duplicated buffers in the result. Default ``True``.

    Returns:
        Iterator[Tuple[str, Tensor]], an iterator of tuple containing the name and buffer.

    Examples:
        >>> import mindspore
        >>>
        >>> class NetB(mindspore.nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
        ...
        ...     def construct(self, x):
        ...         return x + self.buffer_b
        ...
        >>> class NetA(mindspore.nn.Cell):
        ...     def __init__(self, net_b):
        ...         super().__init__()
        ...         self.net_b = net_b
        ...         self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
        ...
        ...     def construct(self, x):
        ...         return self.net_b(x) + self.buffer_a
        ...
        >>> net_b = NetB()
        >>> net_a = NetA(net_b)
        >>>
        >>> for name, buffer in net_a.named_buffers():
        ...     print(f'buffer name is {name}, buffer is {buffer}')
        buffer name is buffer_a, buffer is [4 5 6]
        buffer name is net_b.buffer_b, buffer is [1 2 3]

    """
    def _buffer_items(cell):
        # Member getter handed to ``_named_members``: the cell's (name, buffer) pairs.
        return cell._buffers.items()

    yield from self._named_members(
        _buffer_items,
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )
|
|
667
|
+
|
|
668
|
+
@jit_forbidden_register
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
    r"""Iterate over the cell's buffers, without their names.

    Args:
        recurse (bool, optional): If ``True`` , then yields buffers of this cell
            and all sub cells. Otherwise, yields only buffers that
            are direct members of this cell. Default ``True``.

    Returns:
        Iterator[Tensor], an iterator of buffer.

    Examples:
        >>> import mindspore
        >>>
        >>> class NetB(mindspore.nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
        ...
        ...     def construct(self, x):
        ...         return x + self.buffer_b
        ...
        >>> class NetA(mindspore.nn.Cell):
        ...     def __init__(self, net_b):
        ...         super().__init__()
        ...         self.net_b = net_b
        ...         self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
        ...
        ...     def construct(self, x):
        ...         return self.net_b(x) + self.buffer_a
        ...
        >>> net_b = NetB()
        >>> net_a = NetA(net_b)
        >>>
        >>> for buffer in net_a.buffers():
        ...     print(f'buffer is {buffer}')
        buffer is [4 5 6]
        buffer is [1 2 3]

    """
    # Delegate to ``named_buffers`` and drop the name from each pair.
    yield from (entry for _, entry in self.named_buffers(recurse=recurse))
|
|
714
|
+
|
|
715
|
+
def _named_members(self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True):
    r"""Yield ``(qualified_name, member)`` pairs gathered from this cell and, optionally, its sub cells.

    Args:
        get_members_fn (Callable): Called with a cell; returns an iterable of ``(name, member)`` pairs.
        prefix (str): Prefix for the yielded names. Default ``""``.
        recurse (bool): If ``True`` , walk ``self.cells_and_names(name_prefix=prefix)`` ;
            otherwise only this cell is visited. Default ``True``.
        remove_duplicate (bool): If ``True`` , a member that was already yielded is skipped. Default ``True``.
    """
    seen = set()
    if recurse:
        owners = self.cells_and_names(name_prefix=prefix)
    else:
        owners = [(prefix, self)]
    for owner_prefix, owner in owners:
        for key, member in get_members_fn(owner):
            if member is not None and member not in seen:
                # Members are only remembered when de-duplication is requested.
                if remove_duplicate:
                    seen.add(member)
                separator = "." if owner_prefix else ""
                yield owner_prefix + separator + key, member
|
|
732
|
+
|
|
733
|
+
@jit_forbidden_register
def get_sub_cell(self, target: str) -> "Cell":
    """Look up the sub cell named by `target`, raising an error if it does not exist.

    For example, let's say you have an ``nn.Cell`` ``A`` that
    looks like this:

    .. code-block:: text

        A(
            (net_b): NetB(
                (net_c): NetC(
                    (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                )
                (dense): Dense(in_features=100, out_features=200, bias=True)
            )
        )

    (The diagram shows an ``nn.Cell`` ``A``. ``A`` has a nested
    sub cell ``net_b``, which itself has two sub cells ``net_c``
    and ``dense``. ``net_c`` then has a sub cell ``conv``.)

    To check whether we have the ``dense`` sub cell, we
    would call `get_sub_cell("net_b.dense")`. To check whether
    we have the ``conv`` sub cell, we would call
    `get_sub_cell("net_b.net_c.conv")`.

    The runtime of ``get_sub_cell`` is bounded by the degree
    of cell nesting in `target`. A query against
    `name_cells` achieves the same result, but it is O(N) in
    the number of transitive cells. So, for a simple check to see
    if some sub cells exist, ``get_sub_cell`` should always be
    used.

    Args:
        target (str): The fully-qualified string name of the sub cell
            to look for. (See above example for how to specify a
            fully-qualified string.) An empty string returns this cell itself.

    Returns:
        Cell

    Raises:
        AttributeError: If a component of `target` is not an attribute of the
            cell reached so far, or the attribute is not an ``nn.Cell`` .

    Examples:
        >>> import mindspore
        >>>
        >>> class NetC(mindspore.nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.register_buffer("buffer_c", mindspore.tensor([0, 0, 0]))
        ...         self.dense_c = mindspore.nn.Dense(5, 3)
        ...
        ...     def construct(self, x):
        ...         return self.dense_c(x) + self.buffer_c
        ...
        >>> class NetB(mindspore.nn.Cell):
        ...     def __init__(self, net_c):
        ...         super().__init__()
        ...         self.net_c = net_c
        ...         self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
        ...
        ...     def construct(self, x):
        ...         return self.net_c(x) + self.buffer_b
        ...
        >>> class NetA(mindspore.nn.Cell):
        ...     def __init__(self, net_b):
        ...         super().__init__()
        ...         self.net_b = net_b
        ...         self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
        ...
        ...     def construct(self, x):
        ...         return self.net_b(x) + self.buffer_a
        ...
        >>> net_c = NetC()
        >>> net_b = NetB(net_c)
        >>> net_a = NetA(net_b)
        >>> net_c = net_a.get_sub_cell("net_b.net_c")
        >>> print(f'net_c is {net_c}')
        net_c is NetC(
          (dense_c): Dense(input_channels=5, output_channels=3, has_bias=True)
        )

    """
    # The empty path refers to this cell itself.
    if target == "":
        return self

    current = self
    # Walk one dotted component at a time, validating every hop.
    for part in target.split("."):
        if not hasattr(current, part):
            raise AttributeError(
                current._get_name() + " has no attribute `" + part + "`"
            )
        current = getattr(current, part)
        if not isinstance(current, Cell):
            raise AttributeError("`" + part + "` is not an nn.Cell")
    return current
|
|
837
|
+
|
|
391
838
|
def get_func_graph_proto(self):
|
|
392
839
|
"""Return graph binary proto."""
|
|
393
840
|
exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
|
|
@@ -398,6 +845,10 @@ class Cell(Cell_):
|
|
|
398
845
|
params = self.__dict__['_params']
|
|
399
846
|
if name in params:
|
|
400
847
|
return params[name]
|
|
848
|
+
if '_buffers' in self.__dict__:
|
|
849
|
+
buffers = self.__dict__['_buffers']
|
|
850
|
+
if name in buffers:
|
|
851
|
+
return buffers[name]
|
|
401
852
|
if '_cells' in self.__dict__:
|
|
402
853
|
cells = self.__dict__['_cells']
|
|
403
854
|
if name in cells:
|
|
@@ -420,6 +871,8 @@ class Cell(Cell_):
|
|
|
420
871
|
def __delattr__(self, name):
|
|
421
872
|
if name in self._params:
|
|
422
873
|
del self._params[name]
|
|
874
|
+
elif name in self._buffers:
|
|
875
|
+
del self._buffers[name]
|
|
423
876
|
elif name in self._cells:
|
|
424
877
|
del self._cells[name]
|
|
425
878
|
elif '_params_list' in self.__dict__ and name in self._params_list:
|
|
@@ -492,14 +945,17 @@ class Cell(Cell_):
|
|
|
492
945
|
if self._forward_pre_hook:
|
|
493
946
|
inputs = self._run_forward_pre_hook(inputs)
|
|
494
947
|
|
|
495
|
-
if self.
|
|
496
|
-
output = self._backward_hook_construct(*inputs, **kwargs)
|
|
497
|
-
elif self._shard_fn is not None:
|
|
948
|
+
if self._shard_fn is not None:
|
|
498
949
|
output = self._shard_fn(*inputs, **kwargs)
|
|
499
|
-
elif
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
950
|
+
elif _pynative_executor.requires_grad():
|
|
951
|
+
if self._backward_hook:
|
|
952
|
+
output = self._backward_hook_construct(*inputs, **kwargs)
|
|
953
|
+
elif self._recompute_cell is not None:
|
|
954
|
+
output = self._recompute_cell(*inputs, **kwargs)
|
|
955
|
+
elif self.has_bprop:
|
|
956
|
+
output = self._call_custom_bprop(*inputs, **kwargs)
|
|
957
|
+
else:
|
|
958
|
+
output = self.construct(*inputs, **kwargs)
|
|
503
959
|
else:
|
|
504
960
|
output = self.construct(*inputs, **kwargs)
|
|
505
961
|
|
|
@@ -590,6 +1046,89 @@ class Cell(Cell_):
|
|
|
590
1046
|
for prim in all_prims:
|
|
591
1047
|
prim.add_prim_attr("strategy_gen_mode", "data_parallel")
|
|
592
1048
|
|
|
1049
|
+
def offload(self, backward_prefetch="Auto"):
    """
    Set the cell offload. All primitive ops in the cell will be set offload. For the intermediate
    activations calculated by these primitive ops, we will not save them in the forward pass, but
    offload them and onload them in the backward pass.

    Note:
        - If Cell.offload is called, the mode should be set to "GRAPH_MODE".
        - If Cell.offload is called, lazyinline should be enabled.

    Args:
        backward_prefetch(Union[str, int], optional): The timing for prefetching activations in advance in backward
            pass. Default: ``"Auto"``. If set it to ``"Auto"``, framework
            will start to prefetch activations one operator in advance.
            If set it to an int value, framework will start to
            prefetch activations ``backward_prefetch`` operators in
            advance, such as 1, 20, 100. A string other than ``"Auto"``
            is rejected, and an int is validated as non-negative.

    Raises:
        ValueError: If the current mode is PyNative mode.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import ops
        >>> from mindspore.common import Tensor, Parameter
        >>> from mindspore.common.lazy_inline import lazy_inline
        >>>
        >>> class Block(nn.Cell):
        ...     def __init__(self):
        ...         super(Block, self).__init__()
        ...         self.transpose1 = ops.Transpose()
        ...         self.transpose2 = ops.Transpose()
        ...         self.transpose3 = ops.Transpose()
        ...         self.transpose4 = ops.Transpose()
        ...         self.real_div1 = ops.RealDiv()
        ...         self.real_div2 = ops.RealDiv()
        ...         self.batch_matmul1 = ops.BatchMatMul()
        ...         self.batch_matmul2 = ops.BatchMatMul()
        ...         self.softmax = ops.Softmax(-1)
        ...         self.expand_dims = ops.ExpandDims()
        ...         self.sub = ops.Sub()
        ...         self.y = Parameter(Tensor(np.ones((1024, 128, 128)).astype(np.float32)))
        ...     def construct(self, x):
        ...         transpose1 = self.transpose1(x, (0, 2, 1, 3))
        ...         real_div1 = self.real_div1(transpose1, Tensor(2.37891))
        ...         transpose2 = self.transpose2(x, (0, 2, 3, 1))
        ...         real_div2 = self.real_div2(transpose2, Tensor(2.37891))
        ...         batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
        ...         expand_dims = self.expand_dims(self.y, 1)
        ...         sub = self.sub(Tensor([1.0]), expand_dims)
        ...         soft_max = self.softmax(sub)
        ...         transpose3 = self.transpose3(x, (0, 2, 1, 3))
        ...         batch_matmul2 = self.batch_matmul2(soft_max[0], transpose3)
        ...         transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
        ...         return transpose4
        >>>
        >>> class OuterBlock(nn.Cell):
        ...     @lazy_inline
        ...     def __init__(self):
        ...         super(OuterBlock, self).__init__()
        ...         self.block = Block()
        ...     def construct(self, x):
        ...         return self.block(x)
        >>>
        >>> class Nets(nn.Cell):
        ...     def __init__(self):
        ...         super(Nets, self).__init__()
        ...         self.blocks = nn.CellList()
        ...         for _ in range(3):
        ...             b = OuterBlock()
        ...             b.offload()
        ...             self.blocks.append(b)
        ...     def construct(self, x):
        ...         out = x
        ...         for i in range(3):
        ...             out = self.blocks[i](out)
        ...         return out
    """
    if context._get_mode() == context.PYNATIVE_MODE:
        raise ValueError("The Cell offload does not support PyNative mode now.")
    if isinstance(backward_prefetch, str):
        # "Auto" is the only accepted string value.
        Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
    else:
        # Pass the argument and class names so a validation error identifies its source,
        # matching the check_string call above.
        Validator.check_non_negative_int(backward_prefetch, 'backward_prefetch', self.cls_name)
    # Apply the setting to every primitive collected from this cell.
    for prim in self._get_prims_recursively():
        prim._offload(backward_prefetch=backward_prefetch)
|
|
1131
|
+
|
|
593
1132
|
def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
|
|
594
1133
|
"""
|
|
595
1134
|
Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
|
|
@@ -598,7 +1137,7 @@ class Cell(Cell_):
|
|
|
598
1137
|
strategy for others will be set by sharding propagation.
|
|
599
1138
|
in_strategy and out_strategy define the input and output layout respectively.
|
|
600
1139
|
in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
|
|
601
|
-
this input/output, which can refer to the description of
|
|
1140
|
+
this input/output, which can refer to the description of :func:`mindspore.ops.Primitive.shard`.
|
|
602
1141
|
The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
|
|
603
1142
|
|
|
604
1143
|
Note:
|
|
@@ -618,7 +1157,7 @@ class Cell(Cell_):
|
|
|
618
1157
|
If the parameter name is incorrect or the corresponding parameter
|
|
619
1158
|
has been set, the parameter setting will be ignored.
|
|
620
1159
|
Default: ``None`` .
|
|
621
|
-
device (
|
|
1160
|
+
device (str): Select a certain device target. It is not in use right now.
|
|
622
1161
|
Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
|
|
623
1162
|
level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
|
|
624
1163
|
over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
|
|
@@ -650,10 +1189,8 @@ class Cell(Cell_):
|
|
|
650
1189
|
... x = self.block2_shard(x)
|
|
651
1190
|
... return x
|
|
652
1191
|
"""
|
|
653
|
-
if
|
|
654
|
-
|
|
655
|
-
f"Please check the parallel mode in parallel context.")
|
|
656
|
-
|
|
1192
|
+
if ms.communication.management.get_group_size() == 1:
|
|
1193
|
+
return self
|
|
657
1194
|
shard_fn = Shard()
|
|
658
1195
|
fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
|
|
659
1196
|
self._shard_fn = fn
|
|
@@ -756,7 +1293,8 @@ class Cell(Cell_):
|
|
|
756
1293
|
"""
|
|
757
1294
|
Process cell info before call construct
|
|
758
1295
|
"""
|
|
759
|
-
if self.requires_grad:
|
|
1296
|
+
if self.requires_grad and (not _pynative_executor.grad_flag() or _pynative_executor.high_order()):
|
|
1297
|
+
self.is_top_cell = True
|
|
760
1298
|
_pynative_executor.set_grad_flag(True)
|
|
761
1299
|
_pynative_executor.new_graph(self, *args, **kwargs)
|
|
762
1300
|
elif self._dynamic_shape_inputs is not None:
|
|
@@ -770,8 +1308,9 @@ class Cell(Cell_):
|
|
|
770
1308
|
"""
|
|
771
1309
|
Process cell info after call construct
|
|
772
1310
|
"""
|
|
773
|
-
if self.requires_grad:
|
|
1311
|
+
if self.requires_grad and self.is_top_cell:
|
|
774
1312
|
_pynative_executor.end_graph(self, output, *args, **kwargs)
|
|
1313
|
+
self.is_top_cell = False
|
|
775
1314
|
elif self._dynamic_shape_inputs is not None:
|
|
776
1315
|
_pynative_executor.set_cell_use_dynamic_shape_process(False)
|
|
777
1316
|
|
|
@@ -816,52 +1355,41 @@ class Cell(Cell_):
|
|
|
816
1355
|
self._add_attr(key, value)
|
|
817
1356
|
self._attr_synced = True
|
|
818
1357
|
|
|
819
|
-
def
|
|
820
|
-
"""Set attr for
|
|
821
|
-
|
|
822
|
-
params = self.__dict__.get('_params')
|
|
823
|
-
if params is None:
|
|
824
|
-
raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
|
|
825
|
-
if name in self.__dict__:
|
|
826
|
-
if self.__dict__[name] is not None:
|
|
827
|
-
raise TypeError(f"For 'Cell', the {name} should not be Parameter.")
|
|
828
|
-
del self.__dict__[name]
|
|
829
|
-
if cells and name in cells:
|
|
830
|
-
raise TypeError(f"For 'Cell', the {name} must be Cell, but got Parameter.")
|
|
831
|
-
self.insert_param_to_cell(name, value)
|
|
832
|
-
|
|
833
|
-
def _set_attr_for_parameter_tuple(self, name, value):
|
|
834
|
-
"""Set attr for parameter in ParameterTuple."""
|
|
835
|
-
params = self.__dict__.get('_params')
|
|
836
|
-
params_list = self.__dict__.get('_params_list')
|
|
837
|
-
if params is None:
|
|
838
|
-
raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
|
|
839
|
-
exist_names = set("")
|
|
840
|
-
exist_objs = set()
|
|
841
|
-
for item in value:
|
|
842
|
-
if item in exist_objs:
|
|
843
|
-
# If there are multiple identical objects, their names only check once.
|
|
844
|
-
continue
|
|
845
|
-
exist_objs.add(item)
|
|
846
|
-
if item.name == PARAMETER_NAME_DEFAULT:
|
|
847
|
-
logger.warning("For 'Cell', the parameter definition is deprecated.\n"
|
|
848
|
-
"Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
|
|
849
|
-
item.name = item.name + "$" + str(self._id)
|
|
850
|
-
self._id += 1
|
|
851
|
-
self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
|
|
852
|
-
if item.name in exist_names:
|
|
853
|
-
raise ValueError("The value {} , its name '{}' already exists. "
|
|
854
|
-
"Please set a unique name for the parameter.".format(value, item.name))
|
|
855
|
-
exist_names.add(item.name)
|
|
856
|
-
|
|
857
|
-
if context._get_mode() == context.PYNATIVE_MODE:
|
|
1358
|
+
def _set_attr_for_param_or_param_tuple(self, name, value):
|
|
1359
|
+
"""Set attr for param and tensor."""
|
|
1360
|
+
if isinstance(value, Parameter):
|
|
858
1361
|
if name in self.__dict__:
|
|
859
1362
|
del self.__dict__[name]
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
1363
|
+
self.insert_param_to_cell(name, value)
|
|
1364
|
+
elif isinstance(value, ParameterTuple):
|
|
1365
|
+
exist_names = set("")
|
|
1366
|
+
exist_objs = set()
|
|
1367
|
+
for item in value:
|
|
1368
|
+
if item in exist_objs:
|
|
1369
|
+
# If there are multiple identical objects, their names only check once.
|
|
1370
|
+
continue
|
|
1371
|
+
exist_objs.add(item)
|
|
1372
|
+
if item.name == PARAMETER_NAME_DEFAULT:
|
|
1373
|
+
logger.warning("For 'Cell', the parameter definition is deprecated.\n"
|
|
1374
|
+
"Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
|
|
1375
|
+
item.name = item.name + "$" + str(self._id)
|
|
1376
|
+
self._id += 1
|
|
1377
|
+
self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
|
|
1378
|
+
if item.name in exist_names:
|
|
1379
|
+
raise ValueError("The value {} , its name '{}' already exists. "
|
|
1380
|
+
"Please set a unique name for the parameter.".format(value, item.name))
|
|
1381
|
+
exist_names.add(item.name)
|
|
1382
|
+
|
|
1383
|
+
if context._get_mode() == context.PYNATIVE_MODE:
|
|
1384
|
+
if name in self.__dict__:
|
|
1385
|
+
del self.__dict__[name]
|
|
1386
|
+
params = self.__dict__.get('_params')
|
|
1387
|
+
if name in params:
|
|
1388
|
+
del params[name]
|
|
1389
|
+
params_list = self.__dict__.get('_params_list')
|
|
1390
|
+
params_list[name] = value
|
|
1391
|
+
else:
|
|
1392
|
+
object.__setattr__(self, name, value)
|
|
865
1393
|
|
|
866
1394
|
def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
|
|
867
1395
|
"""Set attr for parameter in list or tuple."""
|
|
@@ -874,24 +1402,18 @@ class Cell(Cell_):
|
|
|
874
1402
|
item.name = item.name + "$" + str(self._id)
|
|
875
1403
|
self._id += 1
|
|
876
1404
|
if item.name in self.exist_names:
|
|
877
|
-
raise ValueError("The value {} , its name '{}' already exists. "
|
|
878
|
-
"Please set a unique name for the parameter."
|
|
1405
|
+
raise ValueError(f"The value {value} , its name '{item.name}' already exists. "
|
|
1406
|
+
"Please set a unique name for the parameter.")
|
|
879
1407
|
self.exist_names.add(item.name)
|
|
880
1408
|
object.__setattr__(self, name, value)
|
|
881
1409
|
|
|
882
1410
|
def _set_attr_for_cell(self, name, value):
|
|
883
1411
|
"""Set attr for cell."""
|
|
884
|
-
cells = self.__dict__.get('_cells')
|
|
885
|
-
params = self.__dict__.get('_params')
|
|
886
|
-
if cells is None:
|
|
887
|
-
raise AttributeError("For 'Cell', can not assign cells before Cell.__init__() is called.")
|
|
888
1412
|
if name in self.__dict__:
|
|
889
1413
|
del self.__dict__[name]
|
|
890
|
-
if params and name in params:
|
|
891
|
-
raise TypeError(f"For 'Cell', the {name} must be Parameter, but got Cell.")
|
|
892
1414
|
if self._auto_prefix:
|
|
893
1415
|
value.update_parameters_name(name + '.')
|
|
894
|
-
|
|
1416
|
+
self.insert_child_to_cell(name, value)
|
|
895
1417
|
if hasattr(self, '_cell_init_args'):
|
|
896
1418
|
self.cell_init_args += str({name: value})
|
|
897
1419
|
|
|
@@ -904,30 +1426,57 @@ class Cell(Cell_):
|
|
|
904
1426
|
else:
|
|
905
1427
|
self.insert_param_to_cell(name, None)
|
|
906
1428
|
|
|
907
|
-
def
|
|
908
|
-
|
|
1429
|
+
def _set_attr_for_object(self, name, value):
|
|
1430
|
+
"""Set attr for py object."""
|
|
909
1431
|
params = self.__dict__.get('_params')
|
|
910
|
-
if
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
1432
|
+
if params is not None and name in params:
|
|
1433
|
+
if value is not None:
|
|
1434
|
+
if isinstance(value, Tensor):
|
|
1435
|
+
params[name].set_data(value)
|
|
1436
|
+
return
|
|
1437
|
+
raise TypeError(
|
|
1438
|
+
f"Parameter '{name}' already exists in network, "
|
|
1439
|
+
f"can not assign this type: '{type(value)}' as a parameter.")
|
|
1440
|
+
params[name] = None
|
|
1441
|
+
return
|
|
1442
|
+
cells = self.__dict__.get('_cells')
|
|
1443
|
+
if cells is not None and name in cells:
|
|
1444
|
+
if value is not None:
|
|
1445
|
+
raise TypeError(
|
|
1446
|
+
f"Sub cell '{name}' already exists in network, "
|
|
1447
|
+
f"can not assign this type: '{type(value)}' as a cell.")
|
|
1448
|
+
cells[name] = None
|
|
1449
|
+
return
|
|
1450
|
+
buffers = self.__dict__.get('_buffers')
|
|
1451
|
+
if buffers is not None and name in buffers:
|
|
1452
|
+
if value is not None:
|
|
1453
|
+
raise TypeError(
|
|
1454
|
+
f"Buffer '{name}' already exists in network, "
|
|
1455
|
+
f"can not assign this type: '{type(value)}' as a buffer.")
|
|
1456
|
+
buffers[name] = None
|
|
1457
|
+
return
|
|
1458
|
+
object.__setattr__(self, name, value)
|
|
1459
|
+
|
|
1460
|
+
def __setattr__(self, name, value):
|
|
1461
|
+
if isinstance(value, (Parameter, ParameterTuple)):
|
|
1462
|
+
self._set_attr_for_param_or_param_tuple(name, value)
|
|
1463
|
+
elif _is_parameter_list_or_tuple(value):
|
|
915
1464
|
self._set_attr_for_parameter_in_list_or_tuple(name, value)
|
|
916
1465
|
elif isinstance(value, Cell):
|
|
917
1466
|
self._set_attr_for_cell(name, value)
|
|
918
|
-
elif
|
|
919
|
-
self.
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
if isinstance(value, Primitive):
|
|
926
|
-
value.set_prim_instance_name(name)
|
|
927
|
-
self._primitives[name] = value
|
|
1467
|
+
elif isinstance(value, _Buffer):
|
|
1468
|
+
if name in self.__dict__:
|
|
1469
|
+
del self.__dict__[name]
|
|
1470
|
+
self.register_buffer(name, value)
|
|
1471
|
+
elif isinstance(value, Primitive):
|
|
1472
|
+
value.set_prim_instance_name(name)
|
|
1473
|
+
self._primitives[name] = value
|
|
928
1474
|
object.__setattr__(self, name, value)
|
|
929
|
-
|
|
930
|
-
self.
|
|
1475
|
+
else:
|
|
1476
|
+
self._set_attr_for_object(name, value)
|
|
1477
|
+
|
|
1478
|
+
def _get_name(self):
|
|
1479
|
+
return self.__class__.__name__
|
|
931
1480
|
|
|
932
1481
|
def extend_repr(self):
|
|
933
1482
|
"""
|
|
@@ -941,19 +1490,28 @@ class Cell(Cell_):
|
|
|
941
1490
|
return self.__repr__()
|
|
942
1491
|
|
|
943
1492
|
def __repr__(self):
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
1493
|
+
extra_lines = []
|
|
1494
|
+
extend_repr = self.extend_repr()
|
|
1495
|
+
# empty string will be split into list ['']
|
|
1496
|
+
if extend_repr:
|
|
1497
|
+
extra_lines = extend_repr.split("\n")
|
|
1498
|
+
child_lines = []
|
|
1499
|
+
for key, cell in self._cells.items():
|
|
1500
|
+
cell_str = repr(cell)
|
|
1501
|
+
cell_str = _addindent(cell_str, 2)
|
|
1502
|
+
child_lines.append("(" + key + "): " + cell_str)
|
|
1503
|
+
lines = extra_lines + child_lines
|
|
1504
|
+
|
|
1505
|
+
main_str = self._get_name() + "("
|
|
1506
|
+
if lines:
|
|
1507
|
+
# simple one-liner info, which most builtin Modules will use
|
|
1508
|
+
if len(extra_lines) == 1 and not child_lines:
|
|
1509
|
+
main_str += extra_lines[0]
|
|
1510
|
+
else:
|
|
1511
|
+
main_str += "\n " + "\n ".join(lines) + "\n"
|
|
1512
|
+
|
|
1513
|
+
main_str += ")"
|
|
1514
|
+
return main_str
|
|
957
1515
|
|
|
958
1516
|
def load_parameter_slice(self, params):
|
|
959
1517
|
"""
|
|
@@ -1119,9 +1677,11 @@ class Cell(Cell_):
|
|
|
1119
1677
|
args (tuple): Args of the Cell object.
|
|
1120
1678
|
kwargs (dict): Kwargs of the Cell object.
|
|
1121
1679
|
"""
|
|
1680
|
+
_init_auto_parallel_context(self)
|
|
1122
1681
|
self._compile_args = self._get_compile_args(args)
|
|
1123
1682
|
_cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
|
|
1124
1683
|
jit_config_dict=self._jit_config_dict, **kwargs)
|
|
1684
|
+
_clear_auto_parallel_context(self)
|
|
1125
1685
|
|
|
1126
1686
|
def compile_and_run(self, *args, **kwargs):
|
|
1127
1687
|
"""
|
|
@@ -1252,9 +1812,9 @@ class Cell(Cell_):
|
|
|
1252
1812
|
>>> net2 = nn.Dense(2, 2)
|
|
1253
1813
|
>>> net1.insert_child_to_cell("child", net2)
|
|
1254
1814
|
>>> print(net1)
|
|
1255
|
-
ReLU
|
|
1256
|
-
(child): Dense
|
|
1257
|
-
|
|
1815
|
+
ReLU(
|
|
1816
|
+
(child): Dense(input_channels=2, output_channels=2, has_bias=True)
|
|
1817
|
+
)
|
|
1258
1818
|
"""
|
|
1259
1819
|
if not isinstance(child_name, str):
|
|
1260
1820
|
raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
|
|
@@ -1312,13 +1872,22 @@ class Cell(Cell_):
|
|
|
1312
1872
|
new_param_tuple.append(param)
|
|
1313
1873
|
cell.__dict__[key] = ParameterTuple(new_param_tuple)
|
|
1314
1874
|
|
|
1875
|
+
def _get_cell_parallel_mode(self):
|
|
1876
|
+
"""Determine whether the current cell is in parallel mode."""
|
|
1877
|
+
is_parallel_mode = False
|
|
1878
|
+
for _, param in self.parameters_and_names():
|
|
1879
|
+
if param.param_info.is_param_init:
|
|
1880
|
+
is_parallel_mode = True
|
|
1881
|
+
break
|
|
1882
|
+
return is_parallel_mode
|
|
1883
|
+
|
|
1315
1884
|
def init_parameters_data(self, auto_parallel_mode=False):
|
|
1316
1885
|
"""
|
|
1317
1886
|
Initialize all parameters and replace the original saved parameters in cell.
|
|
1318
1887
|
|
|
1319
1888
|
Note:
|
|
1320
1889
|
trainable_params() and other similar interfaces may return different parameter instance after
|
|
1321
|
-
`init_parameters_data
|
|
1890
|
+
`init_parameters_data`. It is not recommended to save these results.
|
|
1322
1891
|
|
|
1323
1892
|
Args:
|
|
1324
1893
|
auto_parallel_mode (bool): If running in auto_parallel_mode. Default: ``False`` .
|
|
@@ -1350,15 +1919,24 @@ class Cell(Cell_):
|
|
|
1350
1919
|
def _updata(param):
|
|
1351
1920
|
if param in replace:
|
|
1352
1921
|
return replace.get(param)
|
|
1353
|
-
new_p = param.init_data(None, set_sliced=
|
|
1922
|
+
new_p = param.init_data(None, set_sliced=param.sliced)
|
|
1354
1923
|
replace[param] = new_p
|
|
1355
1924
|
return new_p
|
|
1356
1925
|
|
|
1357
1926
|
# replace all original usage.
|
|
1358
1927
|
cells = self.cells_and_names()
|
|
1928
|
+
is_parallel_mode = self._get_cell_parallel_mode()
|
|
1929
|
+
is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
|
|
1930
|
+
|
|
1359
1931
|
for _, cell in cells:
|
|
1360
1932
|
params = cell._params.items()
|
|
1361
1933
|
for param_name, param in params:
|
|
1934
|
+
not_sliced = not param.sliced
|
|
1935
|
+
judgment = not_sliced
|
|
1936
|
+
if param.param_info.is_pipeline_shared_param:
|
|
1937
|
+
continue
|
|
1938
|
+
if is_graph_mode and is_parallel_mode and judgment:
|
|
1939
|
+
continue
|
|
1362
1940
|
if not auto_parallel_mode:
|
|
1363
1941
|
cell._params[param_name] = _updata(param)
|
|
1364
1942
|
continue
|
|
@@ -1370,6 +1948,12 @@ class Cell(Cell_):
|
|
|
1370
1948
|
param_tuple = cell_dict[key]
|
|
1371
1949
|
new_param_tuple = []
|
|
1372
1950
|
for param in param_tuple:
|
|
1951
|
+
not_sliced = not param.sliced
|
|
1952
|
+
judgment = not_sliced
|
|
1953
|
+
if param.param_info.is_pipeline_shared_param:
|
|
1954
|
+
continue
|
|
1955
|
+
if is_graph_mode and is_parallel_mode and judgment:
|
|
1956
|
+
continue
|
|
1373
1957
|
if not auto_parallel_mode:
|
|
1374
1958
|
new_param_tuple.append(_updata(param))
|
|
1375
1959
|
continue
|
|
@@ -1677,7 +2261,7 @@ class Cell(Cell_):
|
|
|
1677
2261
|
... return x
|
|
1678
2262
|
>>> net = Net()
|
|
1679
2263
|
>>> print(net.cells())
|
|
1680
|
-
odict_values([Dense
|
|
2264
|
+
odict_values([Dense(input_channels=2, output_channels=2, has_bias=True)])
|
|
1681
2265
|
"""
|
|
1682
2266
|
return self.name_cells().values()
|
|
1683
2267
|
|
|
@@ -1738,7 +2322,7 @@ class Cell(Cell_):
|
|
|
1738
2322
|
... return x
|
|
1739
2323
|
>>> net = Net()
|
|
1740
2324
|
>>> print(net.name_cells())
|
|
1741
|
-
OrderedDict([('dense', Dense
|
|
2325
|
+
OrderedDict([('dense', Dense(input_channels=2, output_channels=2, has_bias=True))])
|
|
1742
2326
|
"""
|
|
1743
2327
|
value_set = set()
|
|
1744
2328
|
cells = OrderedDict()
|
|
@@ -1779,10 +2363,10 @@ class Cell(Cell_):
|
|
|
1779
2363
|
... if isinstance(cell, nn.Dense):
|
|
1780
2364
|
... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
|
|
1781
2365
|
>>> net.apply(func)
|
|
1782
|
-
SequentialCell
|
|
1783
|
-
(0): Dense
|
|
1784
|
-
(1): Dense
|
|
1785
|
-
|
|
2366
|
+
SequentialCell(
|
|
2367
|
+
(0): Dense(input_channels=2, output_channels=2, has_bias=True)
|
|
2368
|
+
(1): Dense(input_channels=2, output_channels=2, has_bias=True)
|
|
2369
|
+
)
|
|
1786
2370
|
>>> print(net[0].weight.asnumpy())
|
|
1787
2371
|
[[1. 1.]
|
|
1788
2372
|
[1. 1.]]
|
|
@@ -1914,8 +2498,8 @@ class Cell(Cell_):
|
|
|
1914
2498
|
>>>
|
|
1915
2499
|
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
|
|
1916
2500
|
>>> net.to_float(mstype.float16)
|
|
1917
|
-
Conv2d
|
|
1918
|
-
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW
|
|
2501
|
+
Conv2d(input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
|
|
2502
|
+
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW)
|
|
1919
2503
|
"""
|
|
1920
2504
|
if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
|
|
1921
2505
|
raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
|
|
@@ -1955,9 +2539,8 @@ class Cell(Cell_):
|
|
|
1955
2539
|
|
|
1956
2540
|
def set_grad(self, requires_grad=True):
|
|
1957
2541
|
"""
|
|
1958
|
-
Sets the cell flag for gradient.
|
|
1959
|
-
|
|
1960
|
-
network is executed.
|
|
2542
|
+
Sets the cell flag for gradient.
|
|
2543
|
+
|
|
1961
2544
|
|
|
1962
2545
|
Args:
|
|
1963
2546
|
requires_grad (bool): Specifies if the net need to grad, if it is
|
|
@@ -2121,8 +2704,7 @@ class Cell(Cell_):
|
|
|
2121
2704
|
"""
|
|
2122
2705
|
if context._get_mode() == context.GRAPH_MODE:
|
|
2123
2706
|
return HookHandle()
|
|
2124
|
-
|
|
2125
|
-
return HookHandle()
|
|
2707
|
+
check_hook_fn(hook_fn)
|
|
2126
2708
|
handle = HookHandle(self._forward_pre_hook)
|
|
2127
2709
|
self._forward_pre_hook[handle.handle_id] = hook_fn
|
|
2128
2710
|
return handle
|
|
@@ -2217,10 +2799,11 @@ class Cell(Cell_):
|
|
|
2217
2799
|
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
|
|
2218
2800
|
value= [ 2.00000000e+00]))
|
|
2219
2801
|
"""
|
|
2220
|
-
if
|
|
2802
|
+
if self.has_bprop:
|
|
2221
2803
|
return HookHandle()
|
|
2222
|
-
if
|
|
2804
|
+
if context._get_mode() == context.GRAPH_MODE:
|
|
2223
2805
|
return HookHandle()
|
|
2806
|
+
check_hook_fn(hook_fn)
|
|
2224
2807
|
handle = HookHandle(self._forward_hook)
|
|
2225
2808
|
self._forward_hook[handle.handle_id] = hook_fn
|
|
2226
2809
|
return handle
|
|
@@ -2310,8 +2893,7 @@ class Cell(Cell_):
|
|
|
2310
2893
|
"""
|
|
2311
2894
|
if context._get_mode() == context.GRAPH_MODE:
|
|
2312
2895
|
return HookHandle()
|
|
2313
|
-
|
|
2314
|
-
return HookHandle()
|
|
2896
|
+
check_hook_fn(hook_fn)
|
|
2315
2897
|
handle = HookHandle(self._backward_pre_hook)
|
|
2316
2898
|
self._backward_pre_hook[handle.handle_id] = hook_fn
|
|
2317
2899
|
if self._cell_backward_pre_hook is None:
|
|
@@ -2334,9 +2916,12 @@ class Cell(Cell_):
|
|
|
2334
2916
|
Supported Platforms:
|
|
2335
2917
|
``Ascend`` ``GPU`` ``CPU``
|
|
2336
2918
|
"""
|
|
2337
|
-
ret = self._cell_backward_pre_hook(outputs)
|
|
2338
2919
|
if isinstance(outputs, tuple):
|
|
2339
|
-
|
|
2920
|
+
ret = self._cell_backward_pre_hook(*outputs)
|
|
2921
|
+
else:
|
|
2922
|
+
ret = self._cell_backward_pre_hook(outputs)
|
|
2923
|
+
if isinstance(outputs, tuple):
|
|
2924
|
+
if len(outputs) == 1:
|
|
2340
2925
|
ret = (ret,)
|
|
2341
2926
|
if len(ret) != len(outputs):
|
|
2342
2927
|
raise TypeError(
|
|
@@ -2344,6 +2929,527 @@ class Cell(Cell_):
|
|
|
2344
2929
|
len(ret), len(outputs)))
|
|
2345
2930
|
return ret
|
|
2346
2931
|
|
|
2932
|
+
def get_extra_state(self) -> Any:
|
|
2933
|
+
"""Return any extra state to include in the cell's state_dict.
|
|
2934
|
+
|
|
2935
|
+
This function is called from ``state_dict``.
|
|
2936
|
+
Implement this and a corresponding ``set_extra_state`` for your cell
|
|
2937
|
+
if you need to store extra state.
|
|
2938
|
+
|
|
2939
|
+
Note that extra state should be picklable to ensure working serialization
|
|
2940
|
+
of the state_dict. Only provide backwards compatibility guarantees
|
|
2941
|
+
for serializing tensors; other objects may break backwards compatibility if
|
|
2942
|
+
their serialized pickled form changes.
|
|
2943
|
+
|
|
2944
|
+
Returns:
|
|
2945
|
+
object, any extra state to store in the cell's state_dict.
|
|
2946
|
+
"""
|
|
2947
|
+
raise RuntimeError(
|
|
2948
|
+
"Reached a code path in Cell.get_extra_state() that should never be called."
|
|
2949
|
+
|
|
2950
|
+
)
|
|
2951
|
+
|
|
2952
|
+
def set_extra_state(self, state: Any) -> None:
|
|
2953
|
+
"""Set extra state contained in the loaded `state_dict`.
|
|
2954
|
+
|
|
2955
|
+
This function is called from `load_state_dict` to handle any extra state
|
|
2956
|
+
found within the `state_dict`. Implement this function and a corresponding
|
|
2957
|
+
`get_extra_state` for your cell if you need to store extra state within its
|
|
2958
|
+
`state_dict`.
|
|
2959
|
+
|
|
2960
|
+
Args:
|
|
2961
|
+
state (dict): Extra state from the `state_dict`.
|
|
2962
|
+
"""
|
|
2963
|
+
raise RuntimeError(
|
|
2964
|
+
"Reached a code path in Cell.set_extra_state() that should never be called."
|
|
2965
|
+
)
|
|
2966
|
+
|
|
2967
|
+
@jit_forbidden_register
|
|
2968
|
+
def register_state_dict_post_hook(self, hook):
|
|
2969
|
+
r"""Register a post-hook for the :func:`mindspore.nn.Cell.state_dict` method.
|
|
2970
|
+
|
|
2971
|
+
It should have the following signature:
|
|
2972
|
+
|
|
2973
|
+
hook(cell, state_dict, prefix, local_metadata) -> None
|
|
2974
|
+
|
|
2975
|
+
The registered hooks can modify the ``state_dict`` inplace.
|
|
2976
|
+
|
|
2977
|
+
Args:
|
|
2978
|
+
hook (Callable): The hook function after `state_dict` is called.
|
|
2979
|
+
|
|
2980
|
+
Returns:
|
|
2981
|
+
A handle that can be used to remove the added hook by calling
|
|
2982
|
+
`handle.remove()`.
|
|
2983
|
+
"""
|
|
2984
|
+
from mindspore.utils.hooks import _RemovableHandle
|
|
2985
|
+
handle = _RemovableHandle(self._state_dict_hooks)
|
|
2986
|
+
self._state_dict_hooks[handle.id] = hook
|
|
2987
|
+
return handle
|
|
2988
|
+
|
|
2989
|
+
@jit_forbidden_register
|
|
2990
|
+
def register_state_dict_pre_hook(self, hook):
|
|
2991
|
+
r"""Register a pre-hook for the :func:`mindspore.nn.Cell.state_dict` method.
|
|
2992
|
+
|
|
2993
|
+
It should have the following signature:
|
|
2994
|
+
|
|
2995
|
+
hook(cell, prefix, keep_vars) -> None
|
|
2996
|
+
|
|
2997
|
+
The registered hooks can be used to perform pre-processing before the `state_dict`
|
|
2998
|
+
call is made.
|
|
2999
|
+
|
|
3000
|
+
Args:
|
|
3001
|
+
hook (Callable): The hook function before `state_dict` is called.
|
|
3002
|
+
|
|
3003
|
+
Returns:
|
|
3004
|
+
A handle that can be used to remove the added hook by calling
|
|
3005
|
+
`handle.remove()`.
|
|
3006
|
+
|
|
3007
|
+
Examples:
|
|
3008
|
+
>>> import mindspore
|
|
3009
|
+
...
|
|
3010
|
+
...
|
|
3011
|
+
>>> class NetA(mindspore.nn.Cell):
|
|
3012
|
+
... def __init__(self):
|
|
3013
|
+
... super().__init__()
|
|
3014
|
+
... self.register_buffer("buffer_a", mindspore.tensor([1, 2, 3]))
|
|
3015
|
+
... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
|
|
3016
|
+
...
|
|
3017
|
+
... def construct(self, x):
|
|
3018
|
+
... return x + self.buffer_a + self.param_a
|
|
3019
|
+
...
|
|
3020
|
+
...
|
|
3021
|
+
>>> def _add_extra_param(cell, prefix, keep_vars):
|
|
3022
|
+
... cell._params["extra_param"] = mindspore.Parameter(mindspore.tensor([4, 5, 6]))
|
|
3023
|
+
...
|
|
3024
|
+
...
|
|
3025
|
+
>>> net = NetA()
|
|
3026
|
+
>>> handle = net.register_state_dict_pre_hook(_add_extra_param)
|
|
3027
|
+
>>> net_state_dict = net.state_dict()
|
|
3028
|
+
>>> handle.remove()
|
|
3029
|
+
>>> print("extra_param" in net_state_dict)
|
|
3030
|
+
True
|
|
3031
|
+
"""
|
|
3032
|
+
from mindspore.utils.hooks import _RemovableHandle
|
|
3033
|
+
handle = _RemovableHandle(self._state_dict_pre_hooks)
|
|
3034
|
+
self._state_dict_pre_hooks[handle.id] = hook
|
|
3035
|
+
return handle
|
|
3036
|
+
|
|
3037
|
+
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
3038
|
+
r"""Save cell state to the `destination` dictionary.
|
|
3039
|
+
|
|
3040
|
+
The `destination` dictionary will contain the state
|
|
3041
|
+
of the cell, but not its descendants. This is called on every
|
|
3042
|
+
sub cell in :func:`mindspore.nn.Cell.state_dict`.
|
|
3043
|
+
|
|
3044
|
+
In rare cases, subclasses can achieve class-specific behavior by
|
|
3045
|
+
overriding this method with custom logic.
|
|
3046
|
+
|
|
3047
|
+
Args:
|
|
3048
|
+
destination (dict): a dict where state will be stored
|
|
3049
|
+
prefix (str): the prefix for parameters and buffers used in this
|
|
3050
|
+
cell
|
|
3051
|
+
"""
|
|
3052
|
+
for name, param in self._params.items():
|
|
3053
|
+
if param is not None:
|
|
3054
|
+
destination[prefix + name] = param
|
|
3055
|
+
for name, buf in self._buffers.items():
|
|
3056
|
+
if buf is not None and name not in self._non_persistent_buffers_set:
|
|
3057
|
+
destination[prefix + name] = buf
|
|
3058
|
+
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
|
3059
|
+
if (
|
|
3060
|
+
getattr(self.__class__, "get_extra_state", Cell.get_extra_state)
|
|
3061
|
+
is not Cell.get_extra_state
|
|
3062
|
+
):
|
|
3063
|
+
destination[extra_state_key] = self.get_extra_state()
|
|
3064
|
+
|
|
3065
|
+
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
|
|
3066
|
+
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
|
|
3067
|
+
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
|
|
3068
|
+
|
|
3069
|
+
@jit_forbidden_register
|
|
3070
|
+
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
|
3071
|
+
r"""Return a dictionary containing references to the whole state of the cell.
|
|
3072
|
+
|
|
3073
|
+
Both parameters and persistent buffers (e.g. running averages) are
|
|
3074
|
+
included. Keys are corresponding parameter and buffer names.
|
|
3075
|
+
Parameters and buffers set to ``None`` are not included.
|
|
3076
|
+
|
|
3077
|
+
.. note::
|
|
3078
|
+
The returned object is a shallow copy. It contains references
|
|
3079
|
+
to the cell's parameters and buffers.
|
|
3080
|
+
|
|
3081
|
+
.. warning::
|
|
3082
|
+
- Currently ``state_dict()`` also accepts positional arguments for
|
|
3083
|
+
``destination``, ``prefix`` and ``keep_vars`` in order. However,
|
|
3084
|
+
this is being deprecated and keyword arguments will be enforced in
|
|
3085
|
+
future releases.
|
|
3086
|
+
|
|
3087
|
+
- Please avoid the use of argument ``destination`` as it is not
|
|
3088
|
+
designed for end-users.
|
|
3089
|
+
|
|
3090
|
+
Args:
|
|
3091
|
+
destination (dict, optional): If provided, the state of cell will
|
|
3092
|
+
be updated into the dict and the same object is returned.
|
|
3093
|
+
Otherwise, an ``OrderedDict`` will be created and returned.
|
|
3094
|
+
Default: ``None``.
|
|
3095
|
+
prefix (str, optional): A prefix added to parameter and buffer
|
|
3096
|
+
names to compose the keys in state_dict. Default: ``''``.
|
|
3097
|
+
keep_vars (bool, optional): Whether the state_dict returns a copy. Default: ``False`` , returns a reference.
|
|
3098
|
+
|
|
3099
|
+
Returns:
|
|
3100
|
+
Dict, a dictionary containing a whole state of the cell.
|
|
3101
|
+
|
|
3102
|
+
Examples:
|
|
3103
|
+
>>> import mindspore
|
|
3104
|
+
>>> class Model(mindspore.nn.Cell):
|
|
3105
|
+
... def __init__(self):
|
|
3106
|
+
... super().__init__()
|
|
3107
|
+
... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
|
|
3108
|
+
... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
|
|
3109
|
+
...
|
|
3110
|
+
... def construct(self, x):
|
|
3111
|
+
... return x + self.buffer_a + self.param_a
|
|
3112
|
+
...
|
|
3113
|
+
...
|
|
3114
|
+
>>> model = Model()
|
|
3115
|
+
>>> print(model.state_dict())
|
|
3116
|
+
OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
|
|
3117
|
+
('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
|
|
3118
|
+
"""
|
|
3119
|
+
# TODO: Remove `args` and the parsing logic when BC allows.
|
|
3120
|
+
if args:
|
|
3121
|
+
# DeprecationWarning is ignored by default
|
|
3122
|
+
warnings.warn(
|
|
3123
|
+
"Positional args are being deprecated, use kwargs instead. Refer to "
|
|
3124
|
+
"https://www.mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html"
|
|
3125
|
+
" for details.",
|
|
3126
|
+
FutureWarning,
|
|
3127
|
+
stacklevel=2,
|
|
3128
|
+
)
|
|
3129
|
+
if destination is None:
|
|
3130
|
+
destination = args[0]
|
|
3131
|
+
if len(args) > 1 and prefix == "":
|
|
3132
|
+
prefix = args[1]
|
|
3133
|
+
if len(args) > 2 and keep_vars is False:
|
|
3134
|
+
keep_vars = args[2]
|
|
3135
|
+
if destination is not None and not isinstance(destination, dict):
|
|
3136
|
+
raise TypeError(f"The type of destination must be OrderedDict, but got {type(destination)}")
|
|
3137
|
+
if not isinstance(prefix, str):
|
|
3138
|
+
raise TypeError(f"The type of prefix must be string, but got {type(prefix)}")
|
|
3139
|
+
if not isinstance(keep_vars, bool):
|
|
3140
|
+
raise TypeError(f"The type of keep_vars must be bool, but got {type(keep_vars)}")
|
|
3141
|
+
|
|
3142
|
+
if destination is None:
|
|
3143
|
+
destination = OrderedDict()
|
|
3144
|
+
destination._metadata = OrderedDict()
|
|
3145
|
+
|
|
3146
|
+
local_metadata = {}
|
|
3147
|
+
if hasattr(destination, "_metadata"):
|
|
3148
|
+
destination._metadata[prefix[:-1]] = local_metadata
|
|
3149
|
+
|
|
3150
|
+
for hook in self._state_dict_pre_hooks.values():
|
|
3151
|
+
hook(self, prefix, keep_vars)
|
|
3152
|
+
self._save_to_state_dict(destination, prefix, keep_vars)
|
|
3153
|
+
for name, cell in self._cells.items():
|
|
3154
|
+
if cell is not None:
|
|
3155
|
+
cell.state_dict(
|
|
3156
|
+
destination=destination,
|
|
3157
|
+
prefix=prefix + name + ".",
|
|
3158
|
+
keep_vars=keep_vars,
|
|
3159
|
+
)
|
|
3160
|
+
for hook in self._state_dict_hooks.values():
|
|
3161
|
+
hook_result = hook(self, destination, prefix, local_metadata)
|
|
3162
|
+
if hook_result is not None:
|
|
3163
|
+
raise RuntimeError("state_dict post-hook must return None")
|
|
3164
|
+
return destination
|
|
3165
|
+
|
|
3166
|
+
@jit_forbidden_register
|
|
3167
|
+
def register_load_state_dict_pre_hook(self, hook):
|
|
3168
|
+
r"""Register a pre-hook to be run before cell's :func:`mindspore.nn.Cell.load_state_dict` is called.
|
|
3169
|
+
|
|
3170
|
+
It should have the following signature:
|
|
3171
|
+
|
|
3172
|
+
hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950
|
|
3173
|
+
|
|
3174
|
+
Args:
|
|
3175
|
+
hook (Callable): The hook function before `load_state_dict` is called.
|
|
3176
|
+
|
|
3177
|
+
Returns:
|
|
3178
|
+
A handle that can be used to remove the added hook by calling
|
|
3179
|
+
`handle.remove()`.
|
|
3180
|
+
"""
|
|
3181
|
+
from mindspore.utils.hooks import _RemovableHandle
|
|
3182
|
+
handle = _RemovableHandle(self._load_state_dict_pre_hooks)
|
|
3183
|
+
self._load_state_dict_pre_hooks[handle.id] = hook
|
|
3184
|
+
return handle
|
|
3185
|
+
|
|
3186
|
+
@jit_forbidden_register
|
|
3187
|
+
def register_load_state_dict_post_hook(self, hook):
|
|
3188
|
+
r"""Register a post-hook to be run after cell's :func:`mindspore.nn.Cell.load_state_dict` is called.
|
|
3189
|
+
|
|
3190
|
+
It should have the following signature:
|
|
3191
|
+
|
|
3192
|
+
hook(cell, incompatible_keys) -> None
|
|
3193
|
+
|
|
3194
|
+
The ``cell`` argument is the current cell that this hook is registered
|
|
3195
|
+
on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
|
|
3196
|
+
of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
|
|
3197
|
+
is a ``list`` of ``str`` containing the missing keys and
|
|
3198
|
+
``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
|
|
3199
|
+
|
|
3200
|
+
The given incompatible_keys can be modified inplace if needed.
|
|
3201
|
+
|
|
3202
|
+
Note that the checks performed when calling :func:`load_state_dict` with
|
|
3203
|
+
``strict=True`` are affected by modifications the hook makes to
|
|
3204
|
+
``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
|
|
3205
|
+
set of keys will result in an error being thrown when ``strict=True``, and
|
|
3206
|
+
clearing out both missing and unexpected keys will avoid an error.
|
|
3207
|
+
|
|
3208
|
+
Args:
|
|
3209
|
+
hook (Callable): The hook function after `load_state_dict` is called.
|
|
3210
|
+
|
|
3211
|
+
Returns:
|
|
3212
|
+
A handle that can be used to remove the added hook by calling
|
|
3213
|
+
`handle.remove()`.
|
|
3214
|
+
"""
|
|
3215
|
+
from mindspore.utils.hooks import _RemovableHandle
|
|
3216
|
+
handle = _RemovableHandle(self._load_state_dict_post_hooks)
|
|
3217
|
+
self._load_state_dict_post_hooks[handle.id] = hook
|
|
3218
|
+
return handle
|
|
3219
|
+
|
|
3220
|
+
def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
):
    r"""Load the parameters and persistent buffers owned directly by this cell from :attr:`state_dict`.

    Sub cells are not visited here; :func:`mindspore.nn.Cell.load_state_dict`
    calls this method once per cell. Problems are not raised directly: they are
    appended to the three output lists so the caller can report them together.

    .. note::
        :attr:`state_dict` is not the same object as the input
        :attr:`state_dict` to :func:`mindspore.nn.Cell.load_state_dict`, so
        it can be modified.

    Args:
        state_dict (dict): Mapping from full key to value for parameters and
            persistent buffers.
        prefix (str): Key prefix of the parameters and buffers of this cell.
        local_metadata (dict): Metadata saved for this cell; empty when the
            state dict carries no metadata. Subclasses can use
            `local_metadata.get("version", None)` for backward compatible loading.
        strict (bool): Whether keys under :attr:`prefix` must match the names of
            the parameters and buffers of this cell.
        missing_keys (list of str): If ``strict=True``, expected keys that are
            absent from :attr:`state_dict` are appended here.
        unexpected_keys (list of str): If ``strict=True``, keys under
            :attr:`prefix` that this cell does not own are appended here.
        error_msgs (list of str): Receives one message per entry that could not
            be copied (wrong type, shape mismatch, or failed assignment).
    """
    # Pre-hooks run first and receive exactly the arguments of this call.
    for pre_hook in self._load_state_dict_pre_hooks.values():
        pre_hook(
            self,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    # Collect loadable entries: parameters first, then buffers that are not
    # marked non-persistent. Entries registered as None are left out.
    local_state = {}
    for attr_name, attr_value in self._params.items():
        if attr_value is not None:
            local_state[attr_name] = attr_value
    for attr_name, attr_value in self._buffers.items():
        if attr_name in self._non_persistent_buffers_set:
            continue
        if attr_value is not None:
            local_state[attr_name] = attr_value

    for attr_name, target in local_state.items():
        key = prefix + attr_name
        if key not in state_dict:
            if strict:
                missing_keys.append(key)
            continue

        incoming = state_dict[key]
        if not isinstance(incoming, Tensor):
            error_msgs.append(
                f'While copying the parameter named "{key}", '
                "expected Tensor or Tensor-like object from checkpoint but "
                f"received {type(incoming)}"
            )
            continue

        # local shape should match the one in checkpoint
        if incoming.shape != target.shape:
            error_msgs.append(
                f"size mismatch for {key}: copying a param with shape {incoming.shape} from checkpoint, "
                f"the shape in current model is {target.shape}."
            )
            continue

        try:
            # The incoming value is rebuilt with the dtype of the local entry before assignment.
            target.assign_value(Tensor(incoming.asnumpy(), dtype=target.dtype))
        except Exception as ex:  # pylint: disable=W0703
            error_msgs.append(
                f'While copy the parameter named "{key}", '
                f"whose shape in the model are {target.shape} and "
                f"whose shape in the checkpoint are {incoming.shape}, "
                f"an exception occurred : {ex.args}."
            )

    # Extra state is only consumed when the class overrides set_extra_state.
    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    overrides_extra_state = (
        getattr(self.__class__, "set_extra_state", Cell.set_extra_state) is not Cell.set_extra_state
    )
    if overrides_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

    if not strict:
        return

    # Anything left under this prefix must be a local entry or belong to a sub cell.
    for key in state_dict.keys():
        if not key.startswith(prefix) or key == extra_state_key:
            continue
        head, dot, _ = key[len(prefix):].partition(".")
        if dot:
            # A dotted remainder can only address an attribute of a sub cell.
            if head not in self._cells:
                unexpected_keys.append(key)
        elif head not in local_state:
            unexpected_keys.append(key)
|
|
3334
|
+
|
|
3335
|
+
@jit_forbidden_register
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
    r"""Copy parameters and buffers from :attr:`state_dict` into this cell and its descendants.

    If :attr:`strict` is ``True``, then
    the keys of :attr:`state_dict` must exactly match the keys returned
    by this cell's :func:`mindspore.nn.Cell.state_dict` function.

    Args:
        state_dict (dict): A dict containing parameters and
            persistent buffers.
        strict (bool, optional): Whether to strictly enforce that the keys
            in input `state_dict` match the keys returned by this cell's
            :func:`mindspore.nn.Cell.state_dict` function. Default ``True`` .

    Returns:
        A namedtuple with ``missing_keys`` and ``unexpected_keys`` fields,

        - `missing_keys` is a list of str containing any keys that are expected
          by this cell but missing from the provided ``state_dict``.

        - `unexpected_keys` is a list of str containing the keys that are not
          expected by this cell but present in the provided ``state_dict``.

    Raises:
        TypeError: If `state_dict` is not a mapping.
        RuntimeError: If any error was collected while loading (including, when
            `strict` is ``True``, missing or unexpected keys), or if a hook
            registered with ``register_load_state_dict_post_hook`` returns a
            value other than ``None``.

    Note:
        If `strict` is ``True`` and a parameter or buffer is registered as ``None``, but its corresponding key
        exists in :attr:`state_dict`, :func:`mindspore.nn.Cell.load_state_dict` will raise a ``RuntimeError``.

    Examples:
        >>> import mindspore
        >>> import os
        >>> class Model(mindspore.nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
        ...         self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
        ...
        ...     def construct(self, x):
        ...         return x + self.buffer_a + self.param_a
        ...
        ...
        >>> model = Model()
        >>> print(model.state_dict())
        OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
        ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
        >>> mindspore.save_checkpoint(model.state_dict(), './model_state_dict_ckpt')
        >>> new_model = Model()
        >>> new_model.load_state_dict(mindspore.load_checkpoint('./model_state_dict_ckpt'))
        >>> print(new_model.state_dict())
        OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
        ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
        >>> os.remove('./model_state_dict_ckpt')
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError(
            f"Expected state_dict to be dict-like, got {type(state_dict)}."
        )

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]

    def load(cell, local_state_dict, prefix=""):
        # Metadata is looked up by the cell's dotted path, i.e. the prefix without its trailing ".".
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        # `True` is passed as `strict` on purpose: the caller's `strict` only decides
        # below whether the collected keys are turned into errors.
        cell._load_from_state_dict(
            local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
        )
        for name, child in cell._cells.items():
            if child is not None:
                child_prefix = prefix + name + "."
                # Each child only sees the entries that live under its own prefix.
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)  # noqa: F821

        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in cell._load_state_dict_post_hooks.values():
            out = hook(cell, incompatible_keys)
            if out is not None:
                # The adjacent literals need explicit separating spaces, otherwise
                # the message reads "are notexpected ... modified,it should".
                raise RuntimeError(
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                    "expected to return new values, if incompatible_keys need to be modified, "
                    "it should be done inplace."
                )

    load(self, state_dict)
    # The recursive closure references itself; drop the local name once loading is done.
    del load

    if strict:
        # Each summary is inserted at position 0, so missing keys end up listed before unexpected keys.
        if unexpected_keys:
            error_msgs.insert(
                0,
                "Unexpected key(s) in state_dict: {}. ".format(
                    ", ".join(f'"{k}"' for k in unexpected_keys)
                ),
            )
        if missing_keys:
            error_msgs.insert(
                0,
                "Missing key(s) in state_dict: {}. ".format(
                    ", ".join(f'"{k}"' for k in missing_keys)
                ),
            )

    if error_msgs:
        raise RuntimeError(
            "Error(s) in loading state_dict for {}:\n\t{}".format(
                self.__class__.__name__, "\n\t".join(error_msgs)
            )
        )
    return _IncompatibleKeys(missing_keys, unexpected_keys)
|
|
3452
|
+
|
|
2347
3453
|
def register_backward_hook(self, hook_fn):
|
|
2348
3454
|
"""
|
|
2349
3455
|
Register the backward hook function.
|
|
@@ -2403,8 +3509,7 @@ class Cell(Cell_):
|
|
|
2403
3509
|
"""
|
|
2404
3510
|
if context._get_mode() == context.GRAPH_MODE:
|
|
2405
3511
|
return HookHandle()
|
|
2406
|
-
|
|
2407
|
-
return HookHandle()
|
|
3512
|
+
check_hook_fn(hook_fn)
|
|
2408
3513
|
handle = HookHandle(self._backward_hook)
|
|
2409
3514
|
self._backward_hook[handle.handle_id] = hook_fn
|
|
2410
3515
|
if self._cell_backward_hook is None:
|
|
@@ -2452,9 +3557,14 @@ class Cell(Cell_):
|
|
|
2452
3557
|
outputs = self.construct(*outputs, **kwargs)
|
|
2453
3558
|
else:
|
|
2454
3559
|
outputs = self.construct(outputs, **kwargs)
|
|
2455
|
-
|
|
2456
|
-
|
|
2457
|
-
|
|
3560
|
+
if isinstance(outputs, tuple):
|
|
3561
|
+
new_outputs = self._cell_backward_hook(*outputs)
|
|
3562
|
+
else:
|
|
3563
|
+
new_outputs = self._cell_backward_hook(outputs)
|
|
3564
|
+
# if outputs is (X,) and new_outpus is X
|
|
3565
|
+
if isinstance(outputs, tuple) and len(outputs) == 1:
|
|
3566
|
+
new_outputs = (new_outputs,)
|
|
3567
|
+
return new_outputs
|
|
2458
3568
|
|
|
2459
3569
|
def set_param_ps(self, recurse=True, init_in_server=False):
|
|
2460
3570
|
"""
|
|
@@ -2543,8 +3653,9 @@ class Cell(Cell_):
|
|
|
2543
3653
|
if not self._has_config_recompute:
|
|
2544
3654
|
self._has_config_recompute = True
|
|
2545
3655
|
else:
|
|
2546
|
-
|
|
2547
|
-
|
|
3656
|
+
logger.info("The recompute interface can be configured only once."
|
|
3657
|
+
" When the parent cell is configured, the child cell should not be configured")
|
|
3658
|
+
return
|
|
2548
3659
|
self._set_recompute_scope(mode)
|
|
2549
3660
|
if mode and not output_recompute:
|
|
2550
3661
|
self.add_flags(output_no_recompute=True)
|
|
@@ -2584,18 +3695,13 @@ class Cell(Cell_):
|
|
|
2584
3695
|
"""
|
|
2585
3696
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
2586
3697
|
self._recompute_cell = recompute_registry.get()(self.construct)
|
|
2587
|
-
self._add_recompute_flag()
|
|
2588
|
-
return
|
|
2589
3698
|
self._recompute()
|
|
2590
3699
|
if 'mp_comm_recompute' in kwargs.keys():
|
|
2591
3700
|
self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
|
|
2592
3701
|
if 'parallel_optimizer_comm_recompute' in kwargs.keys():
|
|
2593
|
-
if
|
|
2594
|
-
context.get_auto_parallel_context("pipeline_stages") > 1):
|
|
3702
|
+
if kwargs.get('parallel_optimizer_comm_recompute', False):
|
|
2595
3703
|
logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
|
|
2596
|
-
"
|
|
2597
|
-
elif context.get_auto_parallel_context("pipeline_stages") == 1:
|
|
2598
|
-
self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
|
|
3704
|
+
"is replaced with zero3.")
|
|
2599
3705
|
if 'recompute_slice_activation' in kwargs:
|
|
2600
3706
|
self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
|
|
2601
3707
|
|
|
@@ -2687,17 +3793,91 @@ class Cell(Cell_):
|
|
|
2687
3793
|
if hasattr(network, "_amp_level"):
|
|
2688
3794
|
self._amp_level = getattr(network, "_amp_level")
|
|
2689
3795
|
|
|
2690
|
-
def
|
|
2691
|
-
"""
|
|
2692
|
-
Set pynative cell recomputed.
|
|
3796
|
+
def _register_parameters_hook(self, forward_hook=None, backward_hook=None, all=False):
|
|
2693
3797
|
"""
|
|
2694
|
-
|
|
2695
|
-
|
|
3798
|
+
Register the forward hook for parameters and register the backward hook for the corresponding gradient.
|
|
3799
|
+
|
|
3800
|
+
.. warning::
|
|
3801
|
+
This is an experimental prototype that is subject to change and/or deletion.
|
|
3802
|
+
|
|
3803
|
+
Note:
|
|
3804
|
+
- The `_register_parameters_hook(forward_hook, backward_hook)` only work in graph mode
|
|
3805
|
+
- The `forward_hook` must be defined as the following code.
|
|
3806
|
+
`parameters`: the tuple of the trainble parameters of the Cell, each element in the tuple shuould be
|
|
3807
|
+
in the format of `(param_name, Parameter)`.
|
|
3808
|
+
- The `forward_hook` should have the following signature:
|
|
3809
|
+
forward_hook(parameters) -> None.
|
|
3810
|
+
- The `backward_hook` must be defined as the following code.
|
|
3811
|
+
`gradients`: the tuple of the gradients corresponding to the trainble parameters of the Cell, each
|
|
3812
|
+
element in the tuple shuould be in the format of `(param_name, gradient)`.
|
|
3813
|
+
- The `backward_hook` should have the following signature:
|
|
3814
|
+
backward_hook(parameters) -> New gradients.
|
|
3815
|
+
|
|
3816
|
+
Args:
|
|
3817
|
+
forward_hook (function, optional): Python function or ``None``, Forward hook function. Default: ``None``
|
|
3818
|
+
backward_hook (function, optional): Python function or ``None``, Backward hook function. Default ``None``
|
|
3819
|
+
all (bool, optional): bool, whether to set hooks for all sub cells recursively. Default: ``False``
|
|
3820
|
+
|
|
3821
|
+
Returns:
|
|
3822
|
+
None
|
|
3823
|
+
|
|
3824
|
+
Raises:
|
|
3825
|
+
RuntimeError: If the `forward_hook` or `backward_hook ` has unspoorted syntax under GRAPH MODE.
|
|
3826
|
+
TypeError: If the `forward_hook` or `backward_hook` is not defined as required.
|
|
3827
|
+
|
|
3828
|
+
Supported Platforms:
|
|
3829
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
3830
|
+
|
|
3831
|
+
Examples:
|
|
3832
|
+
>>> import mindspore as ms
|
|
3833
|
+
>>> from mindspore import Tensor, nn, ops, Parameter
|
|
3834
|
+
>>>
|
|
3835
|
+
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
3836
|
+
>>> def parameter_hook(parameters):
|
|
3837
|
+
... print("--- enter parameter hook ---")
|
|
3838
|
+
... for name, param in parameters:
|
|
3839
|
+
... print (name, param)
|
|
3840
|
+
... print("--- leave parameter hook ---")
|
|
3841
|
+
...
|
|
3842
|
+
>>> def gradient_hook(gradients):
|
|
3843
|
+
... print("--- enter gradient hook ---")
|
|
3844
|
+
... outs = []
|
|
3845
|
+
... for name, gradient in gradients:
|
|
3846
|
+
... print(name, gradient)
|
|
3847
|
+
... outs.append(gradient * 2) # double gradient
|
|
3848
|
+
... print("--- leave gradient hook ---")
|
|
3849
|
+
... return outs
|
|
3850
|
+
...
|
|
3851
|
+
>>> class Net(nn.Cell):
|
|
3852
|
+
... def __init__(self)
|
|
3853
|
+
... super(Net, self).__init__()
|
|
3854
|
+
... self.w = Parameter(Tensor(np.array([3.0], np.float32)), name='w')
|
|
3855
|
+
... def construct(self, x):
|
|
3856
|
+
... return self.w * x
|
|
3857
|
+
...
|
|
3858
|
+
>>> grad = ops.GradOperation(get_by_list=True)
|
|
3859
|
+
>>> net = Net()
|
|
3860
|
+
>>> net._register_parameters_hook(forward_hook=parameter_hook, backward_hook=gradient_hook)
|
|
3861
|
+
>>> x = Tensor(np.array([4.0]).astype(np.float32))
|
|
3862
|
+
>>> output = grad(net, net.trainable_params())(x)
|
|
3863
|
+
--- enter parameter hook ---
|
|
3864
|
+
w
|
|
3865
|
+
Tensor(shape=[1], dtype=Float32, value=[ 3.00000000e+00])
|
|
3866
|
+
--- leave parameter hook ---
|
|
3867
|
+
--- enter gradient hook ---
|
|
3868
|
+
w
|
|
3869
|
+
Tensor(shape=[1], dtype=Float32, value=[ 4.00000000e+00])
|
|
3870
|
+
--- leave gradient hook ---
|
|
3871
|
+
>>> print("doubled grad: ", output)
|
|
3872
|
+
doubled grad: (Tensor(shape=[1], dtype=Float32, value=[ 8.00000000e+00]),)
|
|
3873
|
+
"""
|
|
3874
|
+
if not all:
|
|
3875
|
+
self._parameters_forward_hook = forward_hook
|
|
3876
|
+
self._parameters_backward_hook = backward_hook
|
|
2696
3877
|
else:
|
|
2697
|
-
|
|
2698
|
-
|
|
2699
|
-
|
|
2700
|
-
cell._add_recompute_flag()
|
|
3878
|
+
for _, cell in self.cells_and_names():
|
|
3879
|
+
cell._parameters_forward_hook = forward_hook
|
|
3880
|
+
cell._parameters_backward_hook = backward_hook
|
|
2701
3881
|
|
|
2702
3882
|
|
|
2703
3883
|
class GraphCell(Cell):
|
|
@@ -2713,12 +3893,10 @@ class GraphCell(Cell):
|
|
|
2713
3893
|
The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
|
|
2714
3894
|
If the parameter exists in the graph according to the name, update it's value.
|
|
2715
3895
|
If the parameter does not exist, ignore it. Default: ``None`` .
|
|
2716
|
-
obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation
|
|
2717
|
-
used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
|
|
2718
|
-
a func_graph loaded from a mindir file obfuscated with `obf_random_seed` , then `obf_random_seed` should be
|
|
2719
|
-
provided. `obf_random_seed` should be in (0, 9223372036854775807]. default: ``None`` .
|
|
3896
|
+
obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation, which is not supported now.
|
|
2720
3897
|
|
|
2721
3898
|
Raises:
|
|
3899
|
+
NotImplementedError: Dynamic structure obfuscation is not supported now.
|
|
2722
3900
|
TypeError: If the `graph` is not a FuncGraph.
|
|
2723
3901
|
TypeError: If the `params_init` is not a dict.
|
|
2724
3902
|
TypeError: If the key of the `params_init` is not a str.
|
|
@@ -2748,20 +3926,12 @@ class GraphCell(Cell):
|
|
|
2748
3926
|
|
|
2749
3927
|
def __init__(self, graph, params_init=None, obf_random_seed=None):
|
|
2750
3928
|
super(GraphCell, self).__init__(auto_prefix=True)
|
|
3929
|
+
if obf_random_seed is not None:
|
|
3930
|
+
raise NotImplementedError("Dynamic structure obfuscation is not supported now.")
|
|
2751
3931
|
if not isinstance(graph, FuncGraph):
|
|
2752
3932
|
raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
|
|
2753
3933
|
f"but got type {type(graph)}.")
|
|
2754
3934
|
self.graph = graph
|
|
2755
|
-
self.obf_random_seed = obf_random_seed
|
|
2756
|
-
if obf_random_seed is not None:
|
|
2757
|
-
if not isinstance(obf_random_seed, int):
|
|
2758
|
-
raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed)))
|
|
2759
|
-
int_64_max = 9223372036854775807
|
|
2760
|
-
if obf_random_seed <= 0 or obf_random_seed > int_64_max:
|
|
2761
|
-
raise ValueError(
|
|
2762
|
-
"'obf_random_seed' must be larger than 0, and less or equal than int64 ({}),"
|
|
2763
|
-
"but got {}.".format(int_64_max, obf_random_seed))
|
|
2764
|
-
self._branch_control_input = _generate_branch_control_input(self.obf_random_seed)
|
|
2765
3935
|
params_init = {} if params_init is None else params_init
|
|
2766
3936
|
if not isinstance(params_init, dict):
|
|
2767
3937
|
raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
|
|
@@ -2781,19 +3951,30 @@ class GraphCell(Cell):
|
|
|
2781
3951
|
def __call__(self, *args, **kwargs):
    """
    Run the cell on the given inputs.

    Sets the phase to ``graph_load_from_mindir``, attaches ``self.graph`` under
    that same name, and then forwards all positional and keyword arguments
    unchanged to ``compile_and_run``, returning its result.
    """
    # The phase name doubles as the attribute key the graph is attached under.
    mindir_phase = "graph_load_from_mindir"
    self.phase = mindir_phase
    self._add_attr(mindir_phase, self.graph)
    return self.compile_and_run(*args, **kwargs)
|
|
2788
3955
|
|
|
2789
3956
|
|
|
2790
|
-
def
|
|
3957
|
+
def _is_parameter_list_or_tuple(value):
|
|
2791
3958
|
"""
|
|
2792
3959
|
Check the type of input in list or tuple is Parameter.
|
|
2793
3960
|
:param value: list or tuple.
|
|
2794
3961
|
:return: The types of all inputs are parameter.
|
|
2795
3962
|
"""
|
|
2796
|
-
|
|
2797
|
-
|
|
2798
|
-
|
|
2799
|
-
|
|
3963
|
+
if isinstance(value, (list, tuple)) and value:
|
|
3964
|
+
for item in value:
|
|
3965
|
+
if not isinstance(item, Parameter):
|
|
3966
|
+
return False
|
|
3967
|
+
return True
|
|
3968
|
+
return False
|
|
3969
|
+
|
|
3970
|
+
|
|
3971
|
+
def _addindent(s_, num_spaces):
|
|
3972
|
+
s = s_.split("\n")
|
|
3973
|
+
# don't do anything for single-line stuff
|
|
3974
|
+
if len(s) == 1:
|
|
3975
|
+
return s_
|
|
3976
|
+
first = s.pop(0)
|
|
3977
|
+
s = [(num_spaces * " ") + line for line in s]
|
|
3978
|
+
s = "\n".join(s)
|
|
3979
|
+
s = first + "\n" + s
|
|
3980
|
+
return s
|