mindspore 2.5.0__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +25 -194
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +109 -75
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +2014 -3386
- mindspore/common/api.py +386 -355
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/generator.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +332 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +228 -571
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +109 -77
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +115 -147
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +133 -702
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +198 -113
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +234 -28
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1253 -179
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +18 -14
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +32 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
- mindspore/ops/auto_generate/gen_extend_func.py +286 -208
- mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
- mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1631 -2347
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3024 -3855
- mindspore/ops/function/nn_func.py +678 -274
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +216 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +8 -5
- mindspore/ops/functional_overload.py +655 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +21 -14
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +39 -24
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +287 -32
- mindspore/ops/operations/debug_ops.py +119 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +67 -224
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +243 -17
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +140 -12
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +658 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -62
- mindspore/parallel/transform_safetensors.py +288 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +37 -13
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +43 -9
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +262 -127
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +2 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,13 +15,26 @@
|
|
|
15
15
|
"""cell"""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
|
-
import gc
|
|
19
18
|
import inspect
|
|
20
19
|
import os
|
|
21
20
|
import time
|
|
22
|
-
|
|
23
|
-
import
|
|
24
|
-
|
|
21
|
+
import warnings
|
|
22
|
+
import itertools
|
|
23
|
+
from collections import OrderedDict, namedtuple
|
|
24
|
+
from typing import (
|
|
25
|
+
Dict,
|
|
26
|
+
Optional,
|
|
27
|
+
Set,
|
|
28
|
+
Callable,
|
|
29
|
+
List,
|
|
30
|
+
Tuple,
|
|
31
|
+
Iterator,
|
|
32
|
+
Any,
|
|
33
|
+
TypeVar,
|
|
34
|
+
Mapping
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
import mindspore as ms
|
|
25
38
|
from mindspore._checkparam import args_type_check, check_hook_fn
|
|
26
39
|
from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
|
|
27
40
|
from mindspore import log as logger
|
|
@@ -34,19 +47,62 @@ from mindspore import _checkparam as Validator
|
|
|
34
47
|
from mindspore.common import dtype as mstype
|
|
35
48
|
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
|
|
36
49
|
_no_grad
|
|
37
|
-
from mindspore.common.api import
|
|
50
|
+
from mindspore.common.api import _convert_python_data, _get_args_for_run_predict
|
|
38
51
|
from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
|
|
39
|
-
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
52
|
+
from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple
|
|
40
53
|
from mindspore.common.tensor import Tensor
|
|
41
54
|
from mindspore.ops.operations import Cast
|
|
42
55
|
from mindspore.ops.primitive import Primitive
|
|
43
56
|
from mindspore.ops.operations import _inner_ops as inner
|
|
44
57
|
from mindspore.parallel.shard import Shard
|
|
58
|
+
from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
|
|
45
59
|
from mindspore._check_jit_forbidden_api import jit_forbidden_register
|
|
46
60
|
from mindspore.common._decorator import deprecated
|
|
47
61
|
from mindspore.common._register_for_recompute import recompute_registry
|
|
48
62
|
|
|
49
63
|
|
|
64
|
+
__all__ = [
|
|
65
|
+
"register_cell_buffer_registration_hook",
|
|
66
|
+
]
|
|
67
|
+
|
|
68
|
+
_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
|
|
69
|
+
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),):
|
|
73
|
+
def __repr__(self):
|
|
74
|
+
if not self.missing_keys and not self.unexpected_keys:
|
|
75
|
+
return "<All keys matched successfully>"
|
|
76
|
+
return super().__repr__()
|
|
77
|
+
|
|
78
|
+
__str__ = __repr__
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def register_cell_buffer_registration_hook(hook: Callable[..., None],):
|
|
82
|
+
r"""Register a buffer registration hook common to all cells.
|
|
83
|
+
|
|
84
|
+
.. warning ::
|
|
85
|
+
|
|
86
|
+
This adds global state to the `nn.Cell` cell
|
|
87
|
+
|
|
88
|
+
The hook will be called every time :func:`register_buffer` is invoked.
|
|
89
|
+
It should have the following signature::
|
|
90
|
+
|
|
91
|
+
hook(cell, name, buffer) -> None or new buffer
|
|
92
|
+
|
|
93
|
+
The hook can modify the input or return a single modified value in the hook.
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
A handle that can be used to remove the added hook by calling
|
|
97
|
+
`handle.remove()`.
|
|
98
|
+
"""
|
|
99
|
+
from mindspore.utils.hooks import _RemovableHandle
|
|
100
|
+
handle = _RemovableHandle(_global_buffer_registration_hooks)
|
|
101
|
+
_global_buffer_registration_hooks[handle.id] = hook
|
|
102
|
+
return handle
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
|
|
50
106
|
class Cell(Cell_):
|
|
51
107
|
"""
|
|
52
108
|
The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
|
|
@@ -108,6 +164,8 @@ class Cell(Cell_):
|
|
|
108
164
|
'_attr_synced', 'pynative', 'requires_grad', 'cell_type',
|
|
109
165
|
'_parameters_forward_hook', '_parameters_backward_hook']
|
|
110
166
|
total_instance_count = 0
|
|
167
|
+
_buffers: Dict[str, Optional[Tensor]]
|
|
168
|
+
_non_persistent_buffers_set: Set[str]
|
|
111
169
|
|
|
112
170
|
def __init__(self, auto_prefix=True, flags=None):
|
|
113
171
|
Cell_.__init__(self, self._cell_tag)
|
|
@@ -115,10 +173,17 @@ class Cell(Cell_):
|
|
|
115
173
|
self.instance_count = Cell.total_instance_count
|
|
116
174
|
self._params = OrderedDict()
|
|
117
175
|
self._cells = OrderedDict()
|
|
176
|
+
super().__setattr__("_buffers", {})
|
|
177
|
+
super().__setattr__("_non_persistent_buffers_set", set())
|
|
178
|
+
super().__setattr__("_state_dict_hooks", OrderedDict())
|
|
179
|
+
super().__setattr__("_state_dict_pre_hooks", OrderedDict())
|
|
180
|
+
super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
|
|
181
|
+
super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
|
|
118
182
|
self._params_list = OrderedDict()
|
|
119
183
|
self._primitives = OrderedDict()
|
|
120
184
|
self.training = False
|
|
121
185
|
self.requires_grad = False
|
|
186
|
+
self.is_top_cell = False
|
|
122
187
|
self.pynative = False
|
|
123
188
|
self._attr_synced = False
|
|
124
189
|
self._param_prefix = ''
|
|
@@ -135,8 +200,8 @@ class Cell(Cell_):
|
|
|
135
200
|
cells_compile_cache[id(self)] = self.compile_cache
|
|
136
201
|
self.parameter_broadcast_done = False
|
|
137
202
|
self._id = 1
|
|
138
|
-
self.
|
|
139
|
-
self.
|
|
203
|
+
self._exist_objs = None
|
|
204
|
+
self._exist_names = None
|
|
140
205
|
self._recompute_cell = None
|
|
141
206
|
self.mixed_precision_type = None
|
|
142
207
|
self.sig = inspect.signature(self.construct)
|
|
@@ -146,7 +211,6 @@ class Cell(Cell_):
|
|
|
146
211
|
if os.getenv('GC_COLLECT_IN_CELL') == '1':
|
|
147
212
|
logger.warning("The convenient environment 'GC_COLLECT_IN_CELL' is deprecated from version 2.5 "
|
|
148
213
|
"and will be removed in a future version.")
|
|
149
|
-
gc.collect()
|
|
150
214
|
|
|
151
215
|
if flags:
|
|
152
216
|
self.add_flags(**flags)
|
|
@@ -209,6 +273,21 @@ class Cell(Cell_):
|
|
|
209
273
|
def cell_init_args(self):
|
|
210
274
|
return self._cell_init_args
|
|
211
275
|
|
|
276
|
+
@property
|
|
277
|
+
def exist_names(self):
|
|
278
|
+
"""
|
|
279
|
+
Get exist parameter names adding by tuple or list of parameter.
|
|
280
|
+
"""
|
|
281
|
+
if self._exist_names is None:
|
|
282
|
+
self._exist_names = set("")
|
|
283
|
+
return self._exist_names
|
|
284
|
+
|
|
285
|
+
@property
|
|
286
|
+
def exist_objs(self):
|
|
287
|
+
if self._exist_objs is None:
|
|
288
|
+
self._exist_objs = set()
|
|
289
|
+
return self._exist_objs
|
|
290
|
+
|
|
212
291
|
@property
|
|
213
292
|
def param_prefix(self):
|
|
214
293
|
"""
|
|
@@ -237,11 +316,6 @@ class Cell(Cell_):
|
|
|
237
316
|
def bprop_debug(self):
|
|
238
317
|
"""
|
|
239
318
|
Get whether cell custom bprop debug is enabled.
|
|
240
|
-
|
|
241
|
-
Tutorial Examples:
|
|
242
|
-
- `Custom Neural Network Layers - Custom Cell Reverse
|
|
243
|
-
<https://mindspore.cn/docs/en/master/model_train/custom_program/network_custom.html
|
|
244
|
-
#custom-cell-reverse>`_
|
|
245
319
|
"""
|
|
246
320
|
return self._bprop_debug
|
|
247
321
|
|
|
@@ -358,8 +432,6 @@ class Cell(Cell_):
|
|
|
358
432
|
raise ValueError("For 'Cell', the property 'pipeline_stage' "
|
|
359
433
|
"can not be less than 0, but got {}".format(value))
|
|
360
434
|
self._pipeline_stage = value
|
|
361
|
-
for item in self.trainable_params():
|
|
362
|
-
item.add_pipeline_stage(value)
|
|
363
435
|
|
|
364
436
|
@property
|
|
365
437
|
def pipeline_segment(self):
|
|
@@ -395,6 +467,374 @@ class Cell(Cell_):
|
|
|
395
467
|
def enable_backward_hook(self):
|
|
396
468
|
return self._enable_backward_hook
|
|
397
469
|
|
|
470
|
+
@jit_forbidden_register
|
|
471
|
+
def register_buffer(
|
|
472
|
+
self, name: str, tensor: Optional[Tensor], persistent: bool = True
|
|
473
|
+
) -> None:
|
|
474
|
+
r"""Add a buffer to the cell.
|
|
475
|
+
|
|
476
|
+
This is typically used to register a buffer that should not to be
|
|
477
|
+
considered a model parameter. For example, BatchNorm's `running_mean`
|
|
478
|
+
is not a parameter, but is part of the cell's state. Buffers, by
|
|
479
|
+
default, are persistent and will be saved alongside parameters. This
|
|
480
|
+
behavior can be changed by setting `persistent` to ``False`` . The
|
|
481
|
+
only difference between a persistent buffer and a non-persistent buffer
|
|
482
|
+
is that the latter will not be a part of this cell's :attr:`state_dict` .
|
|
483
|
+
|
|
484
|
+
Buffers can be accessed as attributes using given names.
|
|
485
|
+
|
|
486
|
+
Args:
|
|
487
|
+
name (str): name of the buffer. The buffer can be accessed
|
|
488
|
+
from this cell using the given name.
|
|
489
|
+
tensor (Tensor): Buffer to be registered. If ``None`` ,
|
|
490
|
+
the buffer is not included in the cell's :attr:`state_dict` .
|
|
491
|
+
persistent (bool, optional): Whether the buffer is part of this cell's :attr:`state_dict`. Default ``True``.
|
|
492
|
+
|
|
493
|
+
Examples:
|
|
494
|
+
>>> import mindspore
|
|
495
|
+
...
|
|
496
|
+
>>> class Net(mindspore.nn.Cell):
|
|
497
|
+
... def __init__(self):
|
|
498
|
+
... super().__init__()
|
|
499
|
+
... self.register_buffer("buffer0", mindspore.tensor([1, 2, 3]))
|
|
500
|
+
...
|
|
501
|
+
... def construct(self, x):
|
|
502
|
+
... return x + self.net_buffer
|
|
503
|
+
...
|
|
504
|
+
>>> net = Net()
|
|
505
|
+
>>> net.register_buffer("buffer0", mindspore.tensor([4, 5, 6]))
|
|
506
|
+
>>> print(net.buffer0)
|
|
507
|
+
[4 5 6]
|
|
508
|
+
"""
|
|
509
|
+
|
|
510
|
+
if "_buffers" not in self.__dict__:
|
|
511
|
+
raise AttributeError("cannot assign buffer before Cell.__init__() call")
|
|
512
|
+
if not isinstance(name, str):
|
|
513
|
+
raise TypeError(
|
|
514
|
+
f"buffer name should be a string.But got this type: {type(name)}"
|
|
515
|
+
)
|
|
516
|
+
if "." in name:
|
|
517
|
+
raise KeyError('buffer name can\'t contain "."')
|
|
518
|
+
if name == "":
|
|
519
|
+
raise KeyError('buffer name can\'t be empty string ""')
|
|
520
|
+
if hasattr(self, name) and name not in self._buffers:
|
|
521
|
+
raise KeyError(f"attribute '{name}' already exists")
|
|
522
|
+
if tensor is not None and not isinstance(tensor, Tensor):
|
|
523
|
+
raise TypeError(
|
|
524
|
+
f"cannot assign '{type(tensor)}' object to buffer '{name}' "
|
|
525
|
+
"(mindspore Tensor or None required)"
|
|
526
|
+
)
|
|
527
|
+
for hook in _global_buffer_registration_hooks.values():
|
|
528
|
+
output = hook(self, name, tensor)
|
|
529
|
+
if output is not None:
|
|
530
|
+
tensor = output
|
|
531
|
+
if tensor is not None:
|
|
532
|
+
tensor._is_buffer = True
|
|
533
|
+
self._buffers[name] = tensor
|
|
534
|
+
if persistent:
|
|
535
|
+
self._non_persistent_buffers_set.discard(name)
|
|
536
|
+
else:
|
|
537
|
+
self._non_persistent_buffers_set.add(name)
|
|
538
|
+
|
|
539
|
+
@jit_forbidden_register
|
|
540
|
+
def get_buffer(self, target: str) -> "Tensor":
|
|
541
|
+
"""Return the buffer given by `target` if it exists, otherwise throw an error.
|
|
542
|
+
|
|
543
|
+
See the docstring for `get_sub_cell` for a more detailed
|
|
544
|
+
explanation of this method's functionality as well as how to
|
|
545
|
+
correctly specify `target` .
|
|
546
|
+
|
|
547
|
+
Args:
|
|
548
|
+
target (str): The fully-qualified string name of the buffer
|
|
549
|
+
to look for. (See `get_sub_cell` for how to specify a
|
|
550
|
+
fully-qualified string.)
|
|
551
|
+
|
|
552
|
+
Returns:
|
|
553
|
+
Tensor
|
|
554
|
+
|
|
555
|
+
Examples:
|
|
556
|
+
>>> import mindspore
|
|
557
|
+
...
|
|
558
|
+
...
|
|
559
|
+
>>> class NetC(mindspore.nn.Cell):
|
|
560
|
+
... def __init__(self):
|
|
561
|
+
... super().__init__()
|
|
562
|
+
... self.register_buffer("buffer_c", mindspore.tensor([0, 0, 0]))
|
|
563
|
+
...
|
|
564
|
+
... def construct(self, x):
|
|
565
|
+
... return x + self.buffer_c
|
|
566
|
+
...
|
|
567
|
+
...
|
|
568
|
+
>>> class NetB(mindspore.nn.Cell):
|
|
569
|
+
... def __init__(self, net_c):
|
|
570
|
+
... super().__init__()
|
|
571
|
+
... self.net_c = net_c
|
|
572
|
+
... self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
|
|
573
|
+
...
|
|
574
|
+
... def construct(self, x):
|
|
575
|
+
... return self.net_c(x) + self.buffer_b
|
|
576
|
+
...
|
|
577
|
+
...
|
|
578
|
+
>>> class NetA(mindspore.nn.Cell):
|
|
579
|
+
... def __init__(self, net_b):
|
|
580
|
+
... super().__init__()
|
|
581
|
+
... self.net_b = net_b
|
|
582
|
+
... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
|
|
583
|
+
...
|
|
584
|
+
... def construct(self, x):
|
|
585
|
+
... return self.net_b(x) + self.buffer_a
|
|
586
|
+
...
|
|
587
|
+
...
|
|
588
|
+
>>> net_c = NetC()
|
|
589
|
+
>>> net_b = NetB(net_c)
|
|
590
|
+
>>> net_a = NetA(net_b)
|
|
591
|
+
>>> buffer_c = net_a.get_buffer("net_b.net_c.buffer_c")
|
|
592
|
+
>>> print(f'buffer_c is {buffer_c}')
|
|
593
|
+
buffer_c is [0 0 0]
|
|
594
|
+
|
|
595
|
+
"""
|
|
596
|
+
cell_path, _, buffer_name = target.rpartition(".")
|
|
597
|
+
|
|
598
|
+
cell = self.get_sub_cell(cell_path)
|
|
599
|
+
|
|
600
|
+
if not hasattr(cell, buffer_name):
|
|
601
|
+
raise AttributeError(
|
|
602
|
+
cell._get_name() + " has no attribute `" + buffer_name + "`"
|
|
603
|
+
)
|
|
604
|
+
|
|
605
|
+
buffer = getattr(cell, buffer_name)
|
|
606
|
+
|
|
607
|
+
if buffer_name not in cell._buffers:
|
|
608
|
+
raise AttributeError("`" + buffer_name + "` is not a buffer")
|
|
609
|
+
|
|
610
|
+
return buffer
|
|
611
|
+
|
|
612
|
+
@jit_forbidden_register
|
|
613
|
+
def named_buffers(
|
|
614
|
+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
|
|
615
|
+
) -> Iterator[Tuple[str, Tensor]]:
|
|
616
|
+
r"""Return an iterator over cell buffers, yielding both the name of the buffer as well as the buffer itself.
|
|
617
|
+
|
|
618
|
+
Args:
|
|
619
|
+
prefix (str, optional): prefix to prepend to all buffer names. Default ``""``.
|
|
620
|
+
recurse (bool, optional): if ``True`` , then yields buffers of this cell
|
|
621
|
+
and all sub cells. Otherwise, yields only buffers that
|
|
622
|
+
are direct members of this cell. Default ``True``.
|
|
623
|
+
remove_duplicate (bool, optional): Whether to remove the duplicated buffers in the result. Default ``True``.
|
|
624
|
+
|
|
625
|
+
Returns:
|
|
626
|
+
Iterator[Tuple[str, Tensor]], an iterator of tuple containing the name and buffer.
|
|
627
|
+
|
|
628
|
+
Examples:
|
|
629
|
+
>>> import mindspore
|
|
630
|
+
...
|
|
631
|
+
...
|
|
632
|
+
>>> class NetB(mindspore.nn.Cell):
|
|
633
|
+
... def __init__(self):
|
|
634
|
+
... super().__init__()
|
|
635
|
+
... self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
|
|
636
|
+
...
|
|
637
|
+
... def construct(self, x):
|
|
638
|
+
... return x + self.buffer_b
|
|
639
|
+
...
|
|
640
|
+
...
|
|
641
|
+
>>> class NetA(mindspore.nn.Cell):
|
|
642
|
+
... def __init__(self, net_b):
|
|
643
|
+
... super().__init__()
|
|
644
|
+
... self.net_b = net_b
|
|
645
|
+
... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
|
|
646
|
+
...
|
|
647
|
+
... def construct(self, x):
|
|
648
|
+
... return self.net_b(x) + self.buffer_a
|
|
649
|
+
...
|
|
650
|
+
...
|
|
651
|
+
>>> net_b = NetB()
|
|
652
|
+
>>> net_a = NetA(net_b)
|
|
653
|
+
>>>
|
|
654
|
+
>>> for name, buffer in net_a.named_buffers():
|
|
655
|
+
... print(f'buffer name is {name}, buffer is {buffer}')
|
|
656
|
+
buffer name is buffer_a, buffer is [4 5 6]
|
|
657
|
+
buffer name is net_b.buffer_b, buffer is [1 2 3]
|
|
658
|
+
|
|
659
|
+
"""
|
|
660
|
+
gen = self._named_members(
|
|
661
|
+
lambda cell: cell._buffers.items(),
|
|
662
|
+
prefix=prefix,
|
|
663
|
+
recurse=recurse,
|
|
664
|
+
remove_duplicate=remove_duplicate,
|
|
665
|
+
)
|
|
666
|
+
yield from gen
|
|
667
|
+
|
|
668
|
+
@jit_forbidden_register
|
|
669
|
+
def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
|
|
670
|
+
r"""Return an iterator over cell buffers.
|
|
671
|
+
|
|
672
|
+
Args:
|
|
673
|
+
recurse (bool, optional): If ``True`` , then yields buffers of this cell
|
|
674
|
+
and all sub cells. Otherwise, yields only buffers that
|
|
675
|
+
are direct members of this cell. Default ``True``.
|
|
676
|
+
|
|
677
|
+
Returns:
|
|
678
|
+
Iterator[Tensor], an iterator of buffer.
|
|
679
|
+
|
|
680
|
+
Examples:
|
|
681
|
+
>>> import mindspore
|
|
682
|
+
...
|
|
683
|
+
...
|
|
684
|
+
>>> class NetB(mindspore.nn.Cell):
|
|
685
|
+
... def __init__(self):
|
|
686
|
+
... super().__init__()
|
|
687
|
+
... self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
|
|
688
|
+
...
|
|
689
|
+
... def construct(self, x):
|
|
690
|
+
... return x + self.buffer_b
|
|
691
|
+
...
|
|
692
|
+
...
|
|
693
|
+
>>> class NetA(mindspore.nn.Cell):
|
|
694
|
+
... def __init__(self, net_b):
|
|
695
|
+
... super().__init__()
|
|
696
|
+
... self.net_b = net_b
|
|
697
|
+
... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
|
|
698
|
+
...
|
|
699
|
+
... def construct(self, x):
|
|
700
|
+
... return self.net_b(x) + self.buffer_a
|
|
701
|
+
...
|
|
702
|
+
...
|
|
703
|
+
>>> net_b = NetB()
|
|
704
|
+
>>> net_a = NetA(net_b)
|
|
705
|
+
>>>
|
|
706
|
+
>>> for buffer in net_a.buffers():
|
|
707
|
+
... print(f'buffer is {buffer}')
|
|
708
|
+
buffer is [4 5 6]
|
|
709
|
+
buffer is [1 2 3]
|
|
710
|
+
|
|
711
|
+
"""
|
|
712
|
+
for _, buf in self.named_buffers(recurse=recurse):
|
|
713
|
+
yield buf
|
|
714
|
+
|
|
715
|
+
def _named_members(self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True):
|
|
716
|
+
r"""Help yield various names + members of cells."""
|
|
717
|
+
memo = set()
|
|
718
|
+
cells = (
|
|
719
|
+
self.cells_and_names(name_prefix=prefix)
|
|
720
|
+
if recurse
|
|
721
|
+
else [(prefix, self)]
|
|
722
|
+
)
|
|
723
|
+
for cell_prefix, cell in cells:
|
|
724
|
+
members = get_members_fn(cell)
|
|
725
|
+
for k, v in members:
|
|
726
|
+
if v is None or v in memo:
|
|
727
|
+
continue
|
|
728
|
+
if remove_duplicate:
|
|
729
|
+
memo.add(v)
|
|
730
|
+
name = cell_prefix + ("." if cell_prefix else "") + k
|
|
731
|
+
yield name, v
|
|
732
|
+
|
|
733
|
+
@jit_forbidden_register
|
|
734
|
+
def get_sub_cell(self, target: str) -> "Cell":
|
|
735
|
+
"""Return the sub cell given by `target` if it exists, otherwise throw an error.
|
|
736
|
+
|
|
737
|
+
For example, let's say you have an ``nn.Cell`` ``A`` that
|
|
738
|
+
looks like this:
|
|
739
|
+
|
|
740
|
+
.. code-block:: text
|
|
741
|
+
|
|
742
|
+
A(
|
|
743
|
+
(net_b): NetB(
|
|
744
|
+
(net_c): NetC(
|
|
745
|
+
(conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
|
|
746
|
+
)
|
|
747
|
+
(dense): Dense(in_features=100, out_features=200, bias=True)
|
|
748
|
+
)
|
|
749
|
+
)
|
|
750
|
+
|
|
751
|
+
(The diagram shows an ``nn.Cell`` ``A``. ``A`` has a nested
|
|
752
|
+
sub cell ``net_b``, which itself has two sub cells ``net_c``
|
|
753
|
+
and ``dense``. ``net_c`` then has a sub cell ``conv``.)
|
|
754
|
+
|
|
755
|
+
To check whether we have the ``dense`` sub cell, we
|
|
756
|
+
would call `get_sub_cell("net_b.dense")`. To check whether
|
|
757
|
+
we have the ``conv`` sub cell, we would call
|
|
758
|
+
`get_sub_cell("net_b.net_c.conv")`.
|
|
759
|
+
|
|
760
|
+
The runtime of ``get_sub_cell`` is bounded by the degree
|
|
761
|
+
of cell nesting in `target`. A query against
|
|
762
|
+
`name_cells` achieves the same result, but it is O(N) in
|
|
763
|
+
the number of transitive cells. So, for a simple check to see
|
|
764
|
+
if some sub cells exist, ``get_sub_cell`` should always be
|
|
765
|
+
used.
|
|
766
|
+
|
|
767
|
+
Args:
|
|
768
|
+
target (str): The fully-qualified string name of the sub cell
|
|
769
|
+
to look for. (See above example for how to specify a
|
|
770
|
+
fully-qualified string.)
|
|
771
|
+
|
|
772
|
+
Returns:
|
|
773
|
+
Cell
|
|
774
|
+
|
|
775
|
+
Examples:
|
|
776
|
+
>>> import mindspore
|
|
777
|
+
...
|
|
778
|
+
...
|
|
779
|
+
>>> class NetC(mindspore.nn.Cell):
|
|
780
|
+
... def __init__(self):
|
|
781
|
+
... super().__init__()
|
|
782
|
+
... self.register_buffer("buffer_c", mindspore.tensor([0, 0, 0]))
|
|
783
|
+
... self.dense_c = mindspore.nn.Dense(5, 3)
|
|
784
|
+
...
|
|
785
|
+
... def construct(self, x):
|
|
786
|
+
... return self.dense_c(x) + self.buffer_c
|
|
787
|
+
...
|
|
788
|
+
...
|
|
789
|
+
>>> class NetB(mindspore.nn.Cell):
|
|
790
|
+
... def __init__(self, net_c):
|
|
791
|
+
... super().__init__()
|
|
792
|
+
... self.net_c = net_c
|
|
793
|
+
... self.register_buffer("buffer_b", mindspore.tensor([1, 2, 3]))
|
|
794
|
+
...
|
|
795
|
+
... def construct(self, x):
|
|
796
|
+
... return self.net_c(x) + self.buffer_b
|
|
797
|
+
...
|
|
798
|
+
...
|
|
799
|
+
>>> class NetA(mindspore.nn.Cell):
|
|
800
|
+
... def __init__(self, net_b):
|
|
801
|
+
... super().__init__()
|
|
802
|
+
... self.net_b = net_b
|
|
803
|
+
... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
|
|
804
|
+
...
|
|
805
|
+
... def construct(self, x):
|
|
806
|
+
... return self.net_b(x) + self.buffer_a
|
|
807
|
+
...
|
|
808
|
+
...
|
|
809
|
+
>>> net_c = NetC()
|
|
810
|
+
>>> net_b = NetB(net_c)
|
|
811
|
+
>>> net_a = NetA(net_b)
|
|
812
|
+
>>> net_c = net_a.get_sub_cell("net_b.net_c")
|
|
813
|
+
>>> print(f'net_c is {net_c}')
|
|
814
|
+
net_c is NetC(
|
|
815
|
+
(dense_c): Dense(input_channels=5, output_channels=3, has_bias=True)
|
|
816
|
+
)
|
|
817
|
+
|
|
818
|
+
"""
|
|
819
|
+
if target == "":
|
|
820
|
+
return self
|
|
821
|
+
|
|
822
|
+
atoms: List[str] = target.split(".")
|
|
823
|
+
cell = self
|
|
824
|
+
|
|
825
|
+
for item in atoms:
|
|
826
|
+
if not hasattr(cell, item):
|
|
827
|
+
raise AttributeError(
|
|
828
|
+
cell._get_name() + " has no " "attribute `" + item + "`"
|
|
829
|
+
)
|
|
830
|
+
|
|
831
|
+
cell = getattr(cell, item)
|
|
832
|
+
|
|
833
|
+
if not isinstance(cell, Cell):
|
|
834
|
+
raise AttributeError("`" + item + "` is not " "an nn.Cell")
|
|
835
|
+
|
|
836
|
+
return cell
|
|
837
|
+
|
|
398
838
|
def get_func_graph_proto(self):
|
|
399
839
|
"""Return graph binary proto."""
|
|
400
840
|
exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
|
|
@@ -405,6 +845,10 @@ class Cell(Cell_):
|
|
|
405
845
|
params = self.__dict__['_params']
|
|
406
846
|
if name in params:
|
|
407
847
|
return params[name]
|
|
848
|
+
if '_buffers' in self.__dict__:
|
|
849
|
+
buffers = self.__dict__['_buffers']
|
|
850
|
+
if name in buffers:
|
|
851
|
+
return buffers[name]
|
|
408
852
|
if '_cells' in self.__dict__:
|
|
409
853
|
cells = self.__dict__['_cells']
|
|
410
854
|
if name in cells:
|
|
@@ -427,6 +871,8 @@ class Cell(Cell_):
|
|
|
427
871
|
def __delattr__(self, name):
|
|
428
872
|
if name in self._params:
|
|
429
873
|
del self._params[name]
|
|
874
|
+
elif name in self._buffers:
|
|
875
|
+
del self._buffers[name]
|
|
430
876
|
elif name in self._cells:
|
|
431
877
|
del self._cells[name]
|
|
432
878
|
elif '_params_list' in self.__dict__ and name in self._params_list:
|
|
@@ -600,6 +1046,89 @@ class Cell(Cell_):
|
|
|
600
1046
|
for prim in all_prims:
|
|
601
1047
|
prim.add_prim_attr("strategy_gen_mode", "data_parallel")
|
|
602
1048
|
|
|
1049
|
+
def offload(self, backward_prefetch="Auto"):
|
|
1050
|
+
"""
|
|
1051
|
+
Set the cell offload. All primitive ops in the cell will be set offload. For the intermediate
|
|
1052
|
+
activations calculated by these primitive ops, we will not save them in the forward pass, but
|
|
1053
|
+
offload them and onload them in the backward pass.
|
|
1054
|
+
|
|
1055
|
+
Note:
|
|
1056
|
+
- If Cell.offload is called, the mode should be set to "GRAPH_MODE".
|
|
1057
|
+
- If Cell.offload is called, lazyinline should be enabled.
|
|
1058
|
+
|
|
1059
|
+
Args:
|
|
1060
|
+
backward_prefetch(Union[str, int], optional): The timing for prefetching activations in advance in backward
|
|
1061
|
+
pass. Default: ``"Auto"``. If set it to ``"Auto"``, framework
|
|
1062
|
+
will start to prefetch activations one operator in advance.
|
|
1063
|
+
If set it to a positive int value, framework will start to
|
|
1064
|
+
prefetch activations ``backward_prefetch`` operators in
|
|
1065
|
+
advance, such as 1, 20, 100.
|
|
1066
|
+
Examples:
|
|
1067
|
+
>>> import mindspore.nn as nn
|
|
1068
|
+
>>> from mindspore import ops
|
|
1069
|
+
>>> from mindspore.common import Tensor, Parameter
|
|
1070
|
+
>>> from mindspore.common.lazy_inline import lazy_inline
|
|
1071
|
+
>>>
|
|
1072
|
+
>>> class Block(nn.Cell):
|
|
1073
|
+
... def __init__(self):
|
|
1074
|
+
... super(Block, self).__init__()
|
|
1075
|
+
... self.transpose1 = ops.Transpose()
|
|
1076
|
+
... self.transpose2 = ops.Transpose()
|
|
1077
|
+
... self.transpose3 = ops.Transpose()
|
|
1078
|
+
... self.transpose4 = ops.Transpose()
|
|
1079
|
+
... self.real_div1 = ops.RealDiv()
|
|
1080
|
+
... self.real_div2 = ops.RealDiv()
|
|
1081
|
+
... self.batch_matmul1 = ops.BatchMatMul()
|
|
1082
|
+
... self.batch_matmul2 = ops.BatchMatMul()
|
|
1083
|
+
... self.softmax = ops.Softmax(-1)
|
|
1084
|
+
... self.expand_dims = ops.ExpandDims()
|
|
1085
|
+
... self.sub = ops.Sub()
|
|
1086
|
+
... self.y = Parameter(Tensor(np.ones((1024, 128, 128)).astype(np.float32)))
|
|
1087
|
+
... def construct(self, x):
|
|
1088
|
+
... transpose1 = self.transpose1(x, (0, 2, 1, 3))
|
|
1089
|
+
... real_div1 = self.real_div1(transpose1, Tensor(2.37891))
|
|
1090
|
+
... transpose2 = self.transpose2(x, (0, 2, 3, 1))
|
|
1091
|
+
... real_div2 = self.real_div2(transpose2, Tensor(2.37891))
|
|
1092
|
+
... batch_matmul1 = self.batch_matmul1(real_div1, real_div2)
|
|
1093
|
+
... expand_dims = self.expand_dims(self.y, 1)
|
|
1094
|
+
... sub = self.sub(Tensor([1.0]), expand_dims)
|
|
1095
|
+
... soft_max = self.softmax(sub)
|
|
1096
|
+
... transpose3 = self.transpose3(x, (0, 2, 1, 3))
|
|
1097
|
+
... batch_matmul2 = self.batch_matmul2(soft_max[0], transpose3)
|
|
1098
|
+
... transpose4 = self.transpose4(batch_matmul2, (0, 2, 1, 3))
|
|
1099
|
+
... return transpose4
|
|
1100
|
+
>>>
|
|
1101
|
+
>>> class OuterBlock(nn.Cell):
|
|
1102
|
+
... @lazy_inline
|
|
1103
|
+
... def __init__(self):
|
|
1104
|
+
... super(OuterBlock, self).__init__()
|
|
1105
|
+
... self.block = Block()
|
|
1106
|
+
... def construct(self, x):
|
|
1107
|
+
... return self.block(x)
|
|
1108
|
+
>>>
|
|
1109
|
+
>>> class Nets(nn.Cell):
|
|
1110
|
+
... def __init__(self):
|
|
1111
|
+
... super(Nets, self).__init__()
|
|
1112
|
+
... self.blocks = nn.CellList()
|
|
1113
|
+
... for _ in range(3):
|
|
1114
|
+
... b = OuterBlock()
|
|
1115
|
+
... b.offload()
|
|
1116
|
+
... self.blocks.append(b)
|
|
1117
|
+
... def construct(self, x):
|
|
1118
|
+
... out = x
|
|
1119
|
+
... for i in range(3):
|
|
1120
|
+
... out = self.blocks[i](out)
|
|
1121
|
+
... return out
|
|
1122
|
+
"""
|
|
1123
|
+
if context._get_mode() == context.PYNATIVE_MODE:
|
|
1124
|
+
raise ValueError("The Cell offload does not support PyNative mode now.")
|
|
1125
|
+
if isinstance(backward_prefetch, str):
|
|
1126
|
+
Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
|
|
1127
|
+
else:
|
|
1128
|
+
Validator.check_non_negative_int(backward_prefetch)
|
|
1129
|
+
for prim in self._get_prims_recursively():
|
|
1130
|
+
prim._offload(backward_prefetch=backward_prefetch)
|
|
1131
|
+
|
|
603
1132
|
def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
|
|
604
1133
|
"""
|
|
605
1134
|
Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
|
|
@@ -612,9 +1141,9 @@ class Cell(Cell_):
|
|
|
612
1141
|
The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
|
|
613
1142
|
|
|
614
1143
|
Note:
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
If the input contain Parameter, its strategy should be set in `in_strategy`.
|
|
1144
|
+
- It is valid only in semi auto parallel or auto parallel mode.
|
|
1145
|
+
In other parallel modes, strategies set here will be ignored.
|
|
1146
|
+
- If the input contain Parameter, its strategy should be set in `in_strategy`.
|
|
618
1147
|
|
|
619
1148
|
Args:
|
|
620
1149
|
in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
|
|
@@ -628,7 +1157,7 @@ class Cell(Cell_):
|
|
|
628
1157
|
If the parameter name is incorrect or the corresponding parameter
|
|
629
1158
|
has been set, the parameter setting will be ignored.
|
|
630
1159
|
Default: ``None`` .
|
|
631
|
-
device (
|
|
1160
|
+
device (str): Select a certain device target. It is not in use right now.
|
|
632
1161
|
Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
|
|
633
1162
|
level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
|
|
634
1163
|
over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
|
|
@@ -660,10 +1189,8 @@ class Cell(Cell_):
|
|
|
660
1189
|
... x = self.block2_shard(x)
|
|
661
1190
|
... return x
|
|
662
1191
|
"""
|
|
663
|
-
if
|
|
664
|
-
|
|
665
|
-
f"Please check the parallel mode in parallel context.")
|
|
666
|
-
|
|
1192
|
+
if ms.communication.management.get_group_size() == 1:
|
|
1193
|
+
return self
|
|
667
1194
|
shard_fn = Shard()
|
|
668
1195
|
fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
|
|
669
1196
|
self._shard_fn = fn
|
|
@@ -766,7 +1293,8 @@ class Cell(Cell_):
|
|
|
766
1293
|
"""
|
|
767
1294
|
Process cell info before call construct
|
|
768
1295
|
"""
|
|
769
|
-
if self.requires_grad:
|
|
1296
|
+
if self.requires_grad and (not _pynative_executor.grad_flag() or _pynative_executor.high_order()):
|
|
1297
|
+
self.is_top_cell = True
|
|
770
1298
|
_pynative_executor.set_grad_flag(True)
|
|
771
1299
|
_pynative_executor.new_graph(self, *args, **kwargs)
|
|
772
1300
|
elif self._dynamic_shape_inputs is not None:
|
|
@@ -780,8 +1308,9 @@ class Cell(Cell_):
|
|
|
780
1308
|
"""
|
|
781
1309
|
Process cell info after call construct
|
|
782
1310
|
"""
|
|
783
|
-
if self.requires_grad:
|
|
1311
|
+
if self.requires_grad and self.is_top_cell:
|
|
784
1312
|
_pynative_executor.end_graph(self, output, *args, **kwargs)
|
|
1313
|
+
self.is_top_cell = False
|
|
785
1314
|
elif self._dynamic_shape_inputs is not None:
|
|
786
1315
|
_pynative_executor.set_cell_use_dynamic_shape_process(False)
|
|
787
1316
|
|
|
@@ -826,52 +1355,41 @@ class Cell(Cell_):
|
|
|
826
1355
|
self._add_attr(key, value)
|
|
827
1356
|
self._attr_synced = True
|
|
828
1357
|
|
|
829
|
-
def
|
|
830
|
-
"""Set attr for
|
|
831
|
-
|
|
832
|
-
params = self.__dict__.get('_params')
|
|
833
|
-
if params is None:
|
|
834
|
-
raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
|
|
835
|
-
if name in self.__dict__:
|
|
836
|
-
if self.__dict__[name] is not None:
|
|
837
|
-
raise TypeError(f"For 'Cell', the {name} should not be Parameter.")
|
|
838
|
-
del self.__dict__[name]
|
|
839
|
-
if cells and name in cells:
|
|
840
|
-
raise TypeError(f"For 'Cell', the {name} must be Cell, but got Parameter.")
|
|
841
|
-
self.insert_param_to_cell(name, value)
|
|
842
|
-
|
|
843
|
-
def _set_attr_for_parameter_tuple(self, name, value):
|
|
844
|
-
"""Set attr for parameter in ParameterTuple."""
|
|
845
|
-
params = self.__dict__.get('_params')
|
|
846
|
-
params_list = self.__dict__.get('_params_list')
|
|
847
|
-
if params is None:
|
|
848
|
-
raise AttributeError("For 'Cell', can not assign params before Cell.__init__() is called.")
|
|
849
|
-
exist_names = set("")
|
|
850
|
-
exist_objs = set()
|
|
851
|
-
for item in value:
|
|
852
|
-
if item in exist_objs:
|
|
853
|
-
# If there are multiple identical objects, their names only check once.
|
|
854
|
-
continue
|
|
855
|
-
exist_objs.add(item)
|
|
856
|
-
if item.name == PARAMETER_NAME_DEFAULT:
|
|
857
|
-
logger.warning("For 'Cell', the parameter definition is deprecated.\n"
|
|
858
|
-
"Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
|
|
859
|
-
item.name = item.name + "$" + str(self._id)
|
|
860
|
-
self._id += 1
|
|
861
|
-
self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
|
|
862
|
-
if item.name in exist_names:
|
|
863
|
-
raise ValueError("The value {} , its name '{}' already exists. "
|
|
864
|
-
"Please set a unique name for the parameter.".format(value, item.name))
|
|
865
|
-
exist_names.add(item.name)
|
|
866
|
-
|
|
867
|
-
if context._get_mode() == context.PYNATIVE_MODE:
|
|
1358
|
+
def _set_attr_for_param_or_param_tuple(self, name, value):
|
|
1359
|
+
"""Set attr for param and tensor."""
|
|
1360
|
+
if isinstance(value, Parameter):
|
|
868
1361
|
if name in self.__dict__:
|
|
869
1362
|
del self.__dict__[name]
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
1363
|
+
self.insert_param_to_cell(name, value)
|
|
1364
|
+
elif isinstance(value, ParameterTuple):
|
|
1365
|
+
exist_names = set("")
|
|
1366
|
+
exist_objs = set()
|
|
1367
|
+
for item in value:
|
|
1368
|
+
if item in exist_objs:
|
|
1369
|
+
# If there are multiple identical objects, their names only check once.
|
|
1370
|
+
continue
|
|
1371
|
+
exist_objs.add(item)
|
|
1372
|
+
if item.name == PARAMETER_NAME_DEFAULT:
|
|
1373
|
+
logger.warning("For 'Cell', the parameter definition is deprecated.\n"
|
|
1374
|
+
"Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
|
|
1375
|
+
item.name = item.name + "$" + str(self._id)
|
|
1376
|
+
self._id += 1
|
|
1377
|
+
self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
|
|
1378
|
+
if item.name in exist_names:
|
|
1379
|
+
raise ValueError("The value {} , its name '{}' already exists. "
|
|
1380
|
+
"Please set a unique name for the parameter.".format(value, item.name))
|
|
1381
|
+
exist_names.add(item.name)
|
|
1382
|
+
|
|
1383
|
+
if context._get_mode() == context.PYNATIVE_MODE:
|
|
1384
|
+
if name in self.__dict__:
|
|
1385
|
+
del self.__dict__[name]
|
|
1386
|
+
params = self.__dict__.get('_params')
|
|
1387
|
+
if name in params:
|
|
1388
|
+
del params[name]
|
|
1389
|
+
params_list = self.__dict__.get('_params_list')
|
|
1390
|
+
params_list[name] = value
|
|
1391
|
+
else:
|
|
1392
|
+
object.__setattr__(self, name, value)
|
|
875
1393
|
|
|
876
1394
|
def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
|
|
877
1395
|
"""Set attr for parameter in list or tuple."""
|
|
@@ -884,24 +1402,18 @@ class Cell(Cell_):
|
|
|
884
1402
|
item.name = item.name + "$" + str(self._id)
|
|
885
1403
|
self._id += 1
|
|
886
1404
|
if item.name in self.exist_names:
|
|
887
|
-
raise ValueError("The value {} , its name '{}' already exists. "
|
|
888
|
-
"Please set a unique name for the parameter."
|
|
1405
|
+
raise ValueError(f"The value {value} , its name '{item.name}' already exists. "
|
|
1406
|
+
"Please set a unique name for the parameter.")
|
|
889
1407
|
self.exist_names.add(item.name)
|
|
890
1408
|
object.__setattr__(self, name, value)
|
|
891
1409
|
|
|
892
1410
|
def _set_attr_for_cell(self, name, value):
|
|
893
1411
|
"""Set attr for cell."""
|
|
894
|
-
cells = self.__dict__.get('_cells')
|
|
895
|
-
params = self.__dict__.get('_params')
|
|
896
|
-
if cells is None:
|
|
897
|
-
raise AttributeError("For 'Cell', can not assign cells before Cell.__init__() is called.")
|
|
898
1412
|
if name in self.__dict__:
|
|
899
1413
|
del self.__dict__[name]
|
|
900
|
-
if params and name in params:
|
|
901
|
-
raise TypeError(f"For 'Cell', the {name} must be Parameter, but got Cell.")
|
|
902
1414
|
if self._auto_prefix:
|
|
903
1415
|
value.update_parameters_name(name + '.')
|
|
904
|
-
|
|
1416
|
+
self.insert_child_to_cell(name, value)
|
|
905
1417
|
if hasattr(self, '_cell_init_args'):
|
|
906
1418
|
self.cell_init_args += str({name: value})
|
|
907
1419
|
|
|
@@ -914,30 +1426,57 @@ class Cell(Cell_):
|
|
|
914
1426
|
else:
|
|
915
1427
|
self.insert_param_to_cell(name, None)
|
|
916
1428
|
|
|
917
|
-
def
|
|
918
|
-
|
|
1429
|
+
def _set_attr_for_object(self, name, value):
|
|
1430
|
+
"""Set attr for py object."""
|
|
919
1431
|
params = self.__dict__.get('_params')
|
|
920
|
-
if
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
1432
|
+
if params is not None and name in params:
|
|
1433
|
+
if value is not None:
|
|
1434
|
+
if isinstance(value, Tensor):
|
|
1435
|
+
params[name].set_data(value)
|
|
1436
|
+
return
|
|
1437
|
+
raise TypeError(
|
|
1438
|
+
f"Parameter '{name}' already exists in network, "
|
|
1439
|
+
f"can not assign this type: '{type(value)}' as a parameter.")
|
|
1440
|
+
params[name] = None
|
|
1441
|
+
return
|
|
1442
|
+
cells = self.__dict__.get('_cells')
|
|
1443
|
+
if cells is not None and name in cells:
|
|
1444
|
+
if value is not None:
|
|
1445
|
+
raise TypeError(
|
|
1446
|
+
f"Sub cell '{name}' already exists in network, "
|
|
1447
|
+
f"can not assign this type: '{type(value)}' as a cell.")
|
|
1448
|
+
cells[name] = None
|
|
1449
|
+
return
|
|
1450
|
+
buffers = self.__dict__.get('_buffers')
|
|
1451
|
+
if buffers is not None and name in buffers:
|
|
1452
|
+
if value is not None:
|
|
1453
|
+
raise TypeError(
|
|
1454
|
+
f"Buffer '{name}' already exists in network, "
|
|
1455
|
+
f"can not assign this type: '{type(value)}' as a buffer.")
|
|
1456
|
+
buffers[name] = None
|
|
1457
|
+
return
|
|
1458
|
+
object.__setattr__(self, name, value)
|
|
1459
|
+
|
|
1460
|
+
def __setattr__(self, name, value):
|
|
1461
|
+
if isinstance(value, (Parameter, ParameterTuple)):
|
|
1462
|
+
self._set_attr_for_param_or_param_tuple(name, value)
|
|
1463
|
+
elif _is_parameter_list_or_tuple(value):
|
|
925
1464
|
self._set_attr_for_parameter_in_list_or_tuple(name, value)
|
|
926
1465
|
elif isinstance(value, Cell):
|
|
927
1466
|
self._set_attr_for_cell(name, value)
|
|
928
|
-
elif
|
|
929
|
-
self.
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
if isinstance(value, Primitive):
|
|
936
|
-
value.set_prim_instance_name(name)
|
|
937
|
-
self._primitives[name] = value
|
|
1467
|
+
elif isinstance(value, _Buffer):
|
|
1468
|
+
if name in self.__dict__:
|
|
1469
|
+
del self.__dict__[name]
|
|
1470
|
+
self.register_buffer(name, value)
|
|
1471
|
+
elif isinstance(value, Primitive):
|
|
1472
|
+
value.set_prim_instance_name(name)
|
|
1473
|
+
self._primitives[name] = value
|
|
938
1474
|
object.__setattr__(self, name, value)
|
|
939
|
-
|
|
940
|
-
self.
|
|
1475
|
+
else:
|
|
1476
|
+
self._set_attr_for_object(name, value)
|
|
1477
|
+
|
|
1478
|
+
def _get_name(self):
|
|
1479
|
+
return self.__class__.__name__
|
|
941
1480
|
|
|
942
1481
|
def extend_repr(self):
|
|
943
1482
|
"""
|
|
@@ -951,19 +1490,28 @@ class Cell(Cell_):
|
|
|
951
1490
|
return self.__repr__()
|
|
952
1491
|
|
|
953
1492
|
def __repr__(self):
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
1493
|
+
extra_lines = []
|
|
1494
|
+
extend_repr = self.extend_repr()
|
|
1495
|
+
# empty string will be split into list ['']
|
|
1496
|
+
if extend_repr:
|
|
1497
|
+
extra_lines = extend_repr.split("\n")
|
|
1498
|
+
child_lines = []
|
|
1499
|
+
for key, cell in self._cells.items():
|
|
1500
|
+
cell_str = repr(cell)
|
|
1501
|
+
cell_str = _addindent(cell_str, 2)
|
|
1502
|
+
child_lines.append("(" + key + "): " + cell_str)
|
|
1503
|
+
lines = extra_lines + child_lines
|
|
1504
|
+
|
|
1505
|
+
main_str = self._get_name() + "("
|
|
1506
|
+
if lines:
|
|
1507
|
+
# simple one-liner info, which most builtin Modules will use
|
|
1508
|
+
if len(extra_lines) == 1 and not child_lines:
|
|
1509
|
+
main_str += extra_lines[0]
|
|
1510
|
+
else:
|
|
1511
|
+
main_str += "\n " + "\n ".join(lines) + "\n"
|
|
1512
|
+
|
|
1513
|
+
main_str += ")"
|
|
1514
|
+
return main_str
|
|
967
1515
|
|
|
968
1516
|
def load_parameter_slice(self, params):
|
|
969
1517
|
"""
|
|
@@ -1129,9 +1677,11 @@ class Cell(Cell_):
|
|
|
1129
1677
|
args (tuple): Args of the Cell object.
|
|
1130
1678
|
kwargs (dict): Kwargs of the Cell object.
|
|
1131
1679
|
"""
|
|
1680
|
+
_init_auto_parallel_context(self)
|
|
1132
1681
|
self._compile_args = self._get_compile_args(args)
|
|
1133
1682
|
_cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
|
|
1134
1683
|
jit_config_dict=self._jit_config_dict, **kwargs)
|
|
1684
|
+
_clear_auto_parallel_context(self)
|
|
1135
1685
|
|
|
1136
1686
|
def compile_and_run(self, *args, **kwargs):
|
|
1137
1687
|
"""
|
|
@@ -1262,9 +1812,9 @@ class Cell(Cell_):
|
|
|
1262
1812
|
>>> net2 = nn.Dense(2, 2)
|
|
1263
1813
|
>>> net1.insert_child_to_cell("child", net2)
|
|
1264
1814
|
>>> print(net1)
|
|
1265
|
-
ReLU
|
|
1266
|
-
(child): Dense
|
|
1267
|
-
|
|
1815
|
+
ReLU(
|
|
1816
|
+
(child): Dense(input_channels=2, output_channels=2, has_bias=True)
|
|
1817
|
+
)
|
|
1268
1818
|
"""
|
|
1269
1819
|
if not isinstance(child_name, str):
|
|
1270
1820
|
raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
|
|
@@ -1322,13 +1872,22 @@ class Cell(Cell_):
|
|
|
1322
1872
|
new_param_tuple.append(param)
|
|
1323
1873
|
cell.__dict__[key] = ParameterTuple(new_param_tuple)
|
|
1324
1874
|
|
|
1875
|
+
def _get_cell_parallel_mode(self):
|
|
1876
|
+
"""Determine whether the current cell is in parallel mode."""
|
|
1877
|
+
is_parallel_mode = False
|
|
1878
|
+
for _, param in self.parameters_and_names():
|
|
1879
|
+
if param.param_info.is_param_init:
|
|
1880
|
+
is_parallel_mode = True
|
|
1881
|
+
break
|
|
1882
|
+
return is_parallel_mode
|
|
1883
|
+
|
|
1325
1884
|
def init_parameters_data(self, auto_parallel_mode=False):
|
|
1326
1885
|
"""
|
|
1327
1886
|
Initialize all parameters and replace the original saved parameters in cell.
|
|
1328
1887
|
|
|
1329
1888
|
Note:
|
|
1330
1889
|
trainable_params() and other similar interfaces may return different parameter instance after
|
|
1331
|
-
`init_parameters_data
|
|
1890
|
+
`init_parameters_data`. It is not recommended to save these results.
|
|
1332
1891
|
|
|
1333
1892
|
Args:
|
|
1334
1893
|
auto_parallel_mode (bool): If running in auto_parallel_mode. Default: ``False`` .
|
|
@@ -1366,9 +1925,18 @@ class Cell(Cell_):
|
|
|
1366
1925
|
|
|
1367
1926
|
# replace all original usage.
|
|
1368
1927
|
cells = self.cells_and_names()
|
|
1928
|
+
is_parallel_mode = self._get_cell_parallel_mode()
|
|
1929
|
+
is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
|
|
1930
|
+
|
|
1369
1931
|
for _, cell in cells:
|
|
1370
1932
|
params = cell._params.items()
|
|
1371
1933
|
for param_name, param in params:
|
|
1934
|
+
not_sliced = not param.sliced
|
|
1935
|
+
judgment = not_sliced
|
|
1936
|
+
if param.param_info.is_pipeline_shared_param:
|
|
1937
|
+
continue
|
|
1938
|
+
if is_graph_mode and is_parallel_mode and judgment:
|
|
1939
|
+
continue
|
|
1372
1940
|
if not auto_parallel_mode:
|
|
1373
1941
|
cell._params[param_name] = _updata(param)
|
|
1374
1942
|
continue
|
|
@@ -1380,6 +1948,12 @@ class Cell(Cell_):
|
|
|
1380
1948
|
param_tuple = cell_dict[key]
|
|
1381
1949
|
new_param_tuple = []
|
|
1382
1950
|
for param in param_tuple:
|
|
1951
|
+
not_sliced = not param.sliced
|
|
1952
|
+
judgment = not_sliced
|
|
1953
|
+
if param.param_info.is_pipeline_shared_param:
|
|
1954
|
+
continue
|
|
1955
|
+
if is_graph_mode and is_parallel_mode and judgment:
|
|
1956
|
+
continue
|
|
1383
1957
|
if not auto_parallel_mode:
|
|
1384
1958
|
new_param_tuple.append(_updata(param))
|
|
1385
1959
|
continue
|
|
@@ -1687,7 +2261,7 @@ class Cell(Cell_):
|
|
|
1687
2261
|
... return x
|
|
1688
2262
|
>>> net = Net()
|
|
1689
2263
|
>>> print(net.cells())
|
|
1690
|
-
odict_values([Dense
|
|
2264
|
+
odict_values([Dense(input_channels=2, output_channels=2, has_bias=True)])
|
|
1691
2265
|
"""
|
|
1692
2266
|
return self.name_cells().values()
|
|
1693
2267
|
|
|
@@ -1748,7 +2322,7 @@ class Cell(Cell_):
|
|
|
1748
2322
|
... return x
|
|
1749
2323
|
>>> net = Net()
|
|
1750
2324
|
>>> print(net.name_cells())
|
|
1751
|
-
OrderedDict([('dense', Dense
|
|
2325
|
+
OrderedDict([('dense', Dense(input_channels=2, output_channels=2, has_bias=True))])
|
|
1752
2326
|
"""
|
|
1753
2327
|
value_set = set()
|
|
1754
2328
|
cells = OrderedDict()
|
|
@@ -1789,10 +2363,10 @@ class Cell(Cell_):
|
|
|
1789
2363
|
... if isinstance(cell, nn.Dense):
|
|
1790
2364
|
... cell.weight.set_data(initializer(One(), cell.weight.shape, cell.weight.dtype))
|
|
1791
2365
|
>>> net.apply(func)
|
|
1792
|
-
SequentialCell
|
|
1793
|
-
(0): Dense
|
|
1794
|
-
(1): Dense
|
|
1795
|
-
|
|
2366
|
+
SequentialCell(
|
|
2367
|
+
(0): Dense(input_channels=2, output_channels=2, has_bias=True)
|
|
2368
|
+
(1): Dense(input_channels=2, output_channels=2, has_bias=True)
|
|
2369
|
+
)
|
|
1796
2370
|
>>> print(net[0].weight.asnumpy())
|
|
1797
2371
|
[[1. 1.]
|
|
1798
2372
|
[1. 1.]]
|
|
@@ -1832,9 +2406,6 @@ class Cell(Cell_):
|
|
|
1832
2406
|
if not hasattr(self, "_func_graph_flags"):
|
|
1833
2407
|
self._func_graph_flags = {}
|
|
1834
2408
|
self._func_graph_flags.update({**flags})
|
|
1835
|
-
if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
|
|
1836
|
-
raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
|
|
1837
|
-
"'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
|
|
1838
2409
|
self.__dict__.update({**flags})
|
|
1839
2410
|
self._add_mixed_precision_flag(**flags)
|
|
1840
2411
|
return self
|
|
@@ -1927,8 +2498,8 @@ class Cell(Cell_):
|
|
|
1927
2498
|
>>>
|
|
1928
2499
|
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
|
|
1929
2500
|
>>> net.to_float(mstype.float16)
|
|
1930
|
-
Conv2d
|
|
1931
|
-
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW
|
|
2501
|
+
Conv2d(input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
|
|
2502
|
+
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW)
|
|
1932
2503
|
"""
|
|
1933
2504
|
if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
|
|
1934
2505
|
raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
|
|
@@ -2133,8 +2704,7 @@ class Cell(Cell_):
|
|
|
2133
2704
|
"""
|
|
2134
2705
|
if context._get_mode() == context.GRAPH_MODE:
|
|
2135
2706
|
return HookHandle()
|
|
2136
|
-
|
|
2137
|
-
return HookHandle()
|
|
2707
|
+
check_hook_fn(hook_fn)
|
|
2138
2708
|
handle = HookHandle(self._forward_pre_hook)
|
|
2139
2709
|
self._forward_pre_hook[handle.handle_id] = hook_fn
|
|
2140
2710
|
return handle
|
|
@@ -2233,8 +2803,7 @@ class Cell(Cell_):
|
|
|
2233
2803
|
return HookHandle()
|
|
2234
2804
|
if context._get_mode() == context.GRAPH_MODE:
|
|
2235
2805
|
return HookHandle()
|
|
2236
|
-
|
|
2237
|
-
return HookHandle()
|
|
2806
|
+
check_hook_fn(hook_fn)
|
|
2238
2807
|
handle = HookHandle(self._forward_hook)
|
|
2239
2808
|
self._forward_hook[handle.handle_id] = hook_fn
|
|
2240
2809
|
return handle
|
|
@@ -2324,8 +2893,7 @@ class Cell(Cell_):
|
|
|
2324
2893
|
"""
|
|
2325
2894
|
if context._get_mode() == context.GRAPH_MODE:
|
|
2326
2895
|
return HookHandle()
|
|
2327
|
-
|
|
2328
|
-
return HookHandle()
|
|
2896
|
+
check_hook_fn(hook_fn)
|
|
2329
2897
|
handle = HookHandle(self._backward_pre_hook)
|
|
2330
2898
|
self._backward_pre_hook[handle.handle_id] = hook_fn
|
|
2331
2899
|
if self._cell_backward_pre_hook is None:
|
|
@@ -2361,6 +2929,527 @@ class Cell(Cell_):
|
|
|
2361
2929
|
len(ret), len(outputs)))
|
|
2362
2930
|
return ret
|
|
2363
2931
|
|
|
2932
|
+
def get_extra_state(self) -> Any:
|
|
2933
|
+
"""Return any extra state to include in the cell's state_dict.
|
|
2934
|
+
|
|
2935
|
+
This function is called from ``state_dict``.
|
|
2936
|
+
Implement this and a corresponding ``set_extra_state`` for your cell
|
|
2937
|
+
if you need to store extra state.
|
|
2938
|
+
|
|
2939
|
+
Note that extra state should be picklable to ensure working serialization
|
|
2940
|
+
of the state_dict. Only provide backwards compatibility guarantees
|
|
2941
|
+
for serializing tensors; other objects may break backwards compatibility if
|
|
2942
|
+
their serialized pickled form changes.
|
|
2943
|
+
|
|
2944
|
+
Returns:
|
|
2945
|
+
object, any extra state to store in the cell's state_dict.
|
|
2946
|
+
"""
|
|
2947
|
+
raise RuntimeError(
|
|
2948
|
+
"Reached a code path in Cell.get_extra_state() that should never be called."
|
|
2949
|
+
|
|
2950
|
+
)
|
|
2951
|
+
|
|
2952
|
+
def set_extra_state(self, state: Any) -> None:
|
|
2953
|
+
"""Set extra state contained in the loaded `state_dict`.
|
|
2954
|
+
|
|
2955
|
+
This function is called from `load_state_dict` to handle any extra state
|
|
2956
|
+
found within the `state_dict`. Implement this function and a corresponding
|
|
2957
|
+
`get_extra_state` for your cell if you need to store extra state within its
|
|
2958
|
+
`state_dict`.
|
|
2959
|
+
|
|
2960
|
+
Args:
|
|
2961
|
+
state (dict): Extra state from the `state_dict`.
|
|
2962
|
+
"""
|
|
2963
|
+
raise RuntimeError(
|
|
2964
|
+
"Reached a code path in Cell.set_extra_state() that should never be called."
|
|
2965
|
+
)
|
|
2966
|
+
|
|
2967
|
+
@jit_forbidden_register
|
|
2968
|
+
def register_state_dict_post_hook(self, hook):
|
|
2969
|
+
r"""Register a post-hook for the :func:`mindspore.nn.Cell.state_dict` method.
|
|
2970
|
+
|
|
2971
|
+
It should have the following signature:
|
|
2972
|
+
|
|
2973
|
+
hook(cell, state_dict, prefix, local_metadata) -> None
|
|
2974
|
+
|
|
2975
|
+
The registered hooks can modify the ``state_dict`` inplace.
|
|
2976
|
+
|
|
2977
|
+
Args:
|
|
2978
|
+
hook (Callable): The hook function after `state_dict` is called.
|
|
2979
|
+
|
|
2980
|
+
Returns:
|
|
2981
|
+
A handle that can be used to remove the added hook by calling
|
|
2982
|
+
`handle.remove()`.
|
|
2983
|
+
"""
|
|
2984
|
+
from mindspore.utils.hooks import _RemovableHandle
|
|
2985
|
+
handle = _RemovableHandle(self._state_dict_hooks)
|
|
2986
|
+
self._state_dict_hooks[handle.id] = hook
|
|
2987
|
+
return handle
|
|
2988
|
+
|
|
2989
|
+
@jit_forbidden_register
|
|
2990
|
+
def register_state_dict_pre_hook(self, hook):
|
|
2991
|
+
r"""Register a pre-hook for the :func:`mindspore.nn.Cell.state_dict` method.
|
|
2992
|
+
|
|
2993
|
+
It should have the following signature:
|
|
2994
|
+
|
|
2995
|
+
hook(cell, prefix, keep_vars) -> None
|
|
2996
|
+
|
|
2997
|
+
The registered hooks can be used to perform pre-processing before the `state_dict`
|
|
2998
|
+
call is made.
|
|
2999
|
+
|
|
3000
|
+
Args:
|
|
3001
|
+
hook (Callable): The hook function before `state_dict` is called.
|
|
3002
|
+
|
|
3003
|
+
Returns:
|
|
3004
|
+
A handle that can be used to remove the added hook by calling
|
|
3005
|
+
`handle.remove()`.
|
|
3006
|
+
|
|
3007
|
+
Examples:
|
|
3008
|
+
>>> import mindspore
|
|
3009
|
+
...
|
|
3010
|
+
...
|
|
3011
|
+
>>> class NetA(mindspore.nn.Cell):
|
|
3012
|
+
... def __init__(self):
|
|
3013
|
+
... super().__init__()
|
|
3014
|
+
... self.register_buffer("buffer_a", mindspore.tensor([1, 2, 3]))
|
|
3015
|
+
... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
|
|
3016
|
+
...
|
|
3017
|
+
... def construct(self, x):
|
|
3018
|
+
... return x + self.buffer_a + self.param_a
|
|
3019
|
+
...
|
|
3020
|
+
...
|
|
3021
|
+
>>> def _add_extra_param(cell, prefix, keep_vars):
|
|
3022
|
+
... cell._params["extra_param"] = mindspore.Parameter(mindspore.tensor([4, 5, 6]))
|
|
3023
|
+
...
|
|
3024
|
+
...
|
|
3025
|
+
>>> net = NetA()
|
|
3026
|
+
>>> handle = net.register_state_dict_pre_hook(_add_extra_param)
|
|
3027
|
+
>>> net_state_dict = net.state_dict()
|
|
3028
|
+
>>> handle.remove()
|
|
3029
|
+
>>> print("extra_param" in net_state_dict)
|
|
3030
|
+
True
|
|
3031
|
+
"""
|
|
3032
|
+
from mindspore.utils.hooks import _RemovableHandle
|
|
3033
|
+
handle = _RemovableHandle(self._state_dict_pre_hooks)
|
|
3034
|
+
self._state_dict_pre_hooks[handle.id] = hook
|
|
3035
|
+
return handle
|
|
3036
|
+
|
|
3037
|
+
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
|
3038
|
+
r"""Save cell state to the `destination` dictionary.
|
|
3039
|
+
|
|
3040
|
+
The `destination` dictionary will contain the state
|
|
3041
|
+
of the cell, but not its descendants. This is called on every
|
|
3042
|
+
sub cell in :func:`mindspore.nn.Cell.state_dict`.
|
|
3043
|
+
|
|
3044
|
+
In rare cases, subclasses can achieve class-specific behavior by
|
|
3045
|
+
overriding this method with custom logic.
|
|
3046
|
+
|
|
3047
|
+
Args:
|
|
3048
|
+
destination (dict): a dict where state will be stored
|
|
3049
|
+
prefix (str): the prefix for parameters and buffers used in this
|
|
3050
|
+
cell
|
|
3051
|
+
"""
|
|
3052
|
+
for name, param in self._params.items():
|
|
3053
|
+
if param is not None:
|
|
3054
|
+
destination[prefix + name] = param
|
|
3055
|
+
for name, buf in self._buffers.items():
|
|
3056
|
+
if buf is not None and name not in self._non_persistent_buffers_set:
|
|
3057
|
+
destination[prefix + name] = buf
|
|
3058
|
+
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
|
3059
|
+
if (
|
|
3060
|
+
getattr(self.__class__, "get_extra_state", Cell.get_extra_state)
|
|
3061
|
+
is not Cell.get_extra_state
|
|
3062
|
+
):
|
|
3063
|
+
destination[extra_state_key] = self.get_extra_state()
|
|
3064
|
+
|
|
3065
|
+
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
|
|
3066
|
+
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
|
|
3067
|
+
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
|
|
3068
|
+
|
|
3069
|
+
@jit_forbidden_register
|
|
3070
|
+
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
|
3071
|
+
r"""Return a dictionary containing references to the whole state of the cell.
|
|
3072
|
+
|
|
3073
|
+
Both parameters and persistent buffers (e.g. running averages) are
|
|
3074
|
+
included. Keys are corresponding parameter and buffer names.
|
|
3075
|
+
Parameters and buffers set to ``None`` are not included.
|
|
3076
|
+
|
|
3077
|
+
.. note::
|
|
3078
|
+
The returned object is a shallow copy. It contains references
|
|
3079
|
+
to the cell's parameters and buffers.
|
|
3080
|
+
|
|
3081
|
+
.. warning::
|
|
3082
|
+
- Currently ``state_dict()`` also accepts positional arguments for
|
|
3083
|
+
``destination``, ``prefix`` and ``keep_vars`` in order. However,
|
|
3084
|
+
this is being deprecated and keyword arguments will be enforced in
|
|
3085
|
+
future releases.
|
|
3086
|
+
|
|
3087
|
+
- Please avoid the use of argument ``destination`` as it is not
|
|
3088
|
+
designed for end-users.
|
|
3089
|
+
|
|
3090
|
+
Args:
|
|
3091
|
+
destination (dict, optional): If provided, the state of cell will
|
|
3092
|
+
be updated into the dict and the same object is returned.
|
|
3093
|
+
Otherwise, an ``OrderedDict`` will be created and returned.
|
|
3094
|
+
Default: ``None``.
|
|
3095
|
+
prefix (str, optional): A prefix added to parameter and buffer
|
|
3096
|
+
names to compose the keys in state_dict. Default: ``''``.
|
|
3097
|
+
keep_vars (bool, optional): Whether the state_dict returns a copy. Default: ``False`` , returns a reference.
|
|
3098
|
+
|
|
3099
|
+
Returns:
|
|
3100
|
+
Dict, a dictionary containing a whole state of the cell.
|
|
3101
|
+
|
|
3102
|
+
Examples:
|
|
3103
|
+
>>> import mindspore
|
|
3104
|
+
>>> class Model(mindspore.nn.Cell):
|
|
3105
|
+
... def __init__(self):
|
|
3106
|
+
... super().__init__()
|
|
3107
|
+
... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
|
|
3108
|
+
... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
|
|
3109
|
+
...
|
|
3110
|
+
... def construct(self, x):
|
|
3111
|
+
... return x + self.buffer_a + self.param_a
|
|
3112
|
+
...
|
|
3113
|
+
...
|
|
3114
|
+
>>> model = Model()
|
|
3115
|
+
>>> print(model.state_dict())
|
|
3116
|
+
OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
|
|
3117
|
+
('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
|
|
3118
|
+
"""
|
|
3119
|
+
# TODO: Remove `args` and the parsing logic when BC allows.
|
|
3120
|
+
if args:
|
|
3121
|
+
# DeprecationWarning is ignored by default
|
|
3122
|
+
warnings.warn(
|
|
3123
|
+
"Positional args are being deprecated, use kwargs instead. Refer to "
|
|
3124
|
+
"https://www.mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html"
|
|
3125
|
+
" for details.",
|
|
3126
|
+
FutureWarning,
|
|
3127
|
+
stacklevel=2,
|
|
3128
|
+
)
|
|
3129
|
+
if destination is None:
|
|
3130
|
+
destination = args[0]
|
|
3131
|
+
if len(args) > 1 and prefix == "":
|
|
3132
|
+
prefix = args[1]
|
|
3133
|
+
if len(args) > 2 and keep_vars is False:
|
|
3134
|
+
keep_vars = args[2]
|
|
3135
|
+
if destination is not None and not isinstance(destination, dict):
|
|
3136
|
+
raise TypeError(f"The type of destination must be OrderedDict, but got {type(destination)}")
|
|
3137
|
+
if not isinstance(prefix, str):
|
|
3138
|
+
raise TypeError(f"The type of prefix must be string, but got {type(prefix)}")
|
|
3139
|
+
if not isinstance(keep_vars, bool):
|
|
3140
|
+
raise TypeError(f"The type of keep_vars must be bool, but got {type(keep_vars)}")
|
|
3141
|
+
|
|
3142
|
+
if destination is None:
|
|
3143
|
+
destination = OrderedDict()
|
|
3144
|
+
destination._metadata = OrderedDict()
|
|
3145
|
+
|
|
3146
|
+
local_metadata = {}
|
|
3147
|
+
if hasattr(destination, "_metadata"):
|
|
3148
|
+
destination._metadata[prefix[:-1]] = local_metadata
|
|
3149
|
+
|
|
3150
|
+
for hook in self._state_dict_pre_hooks.values():
|
|
3151
|
+
hook(self, prefix, keep_vars)
|
|
3152
|
+
self._save_to_state_dict(destination, prefix, keep_vars)
|
|
3153
|
+
for name, cell in self._cells.items():
|
|
3154
|
+
if cell is not None:
|
|
3155
|
+
cell.state_dict(
|
|
3156
|
+
destination=destination,
|
|
3157
|
+
prefix=prefix + name + ".",
|
|
3158
|
+
keep_vars=keep_vars,
|
|
3159
|
+
)
|
|
3160
|
+
for hook in self._state_dict_hooks.values():
|
|
3161
|
+
hook_result = hook(self, destination, prefix, local_metadata)
|
|
3162
|
+
if hook_result is not None:
|
|
3163
|
+
raise RuntimeError("state_dict post-hook must return None")
|
|
3164
|
+
return destination
|
|
3165
|
+
|
|
3166
|
+
@jit_forbidden_register
|
|
3167
|
+
def register_load_state_dict_pre_hook(self, hook):
|
|
3168
|
+
r"""Register a pre-hook to be run before cell's :func:`mindspore.nn.Cell.load_state_dict` is called.
|
|
3169
|
+
|
|
3170
|
+
It should have the following signature:
|
|
3171
|
+
|
|
3172
|
+
hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950
|
|
3173
|
+
|
|
3174
|
+
Args:
|
|
3175
|
+
hook (Callable): The hook function before `load_state_dict` is called.
|
|
3176
|
+
|
|
3177
|
+
Returns:
|
|
3178
|
+
A handle that can be used to remove the added hook by calling
|
|
3179
|
+
`handle.remove()`.
|
|
3180
|
+
"""
|
|
3181
|
+
from mindspore.utils.hooks import _RemovableHandle
|
|
3182
|
+
handle = _RemovableHandle(self._load_state_dict_pre_hooks)
|
|
3183
|
+
self._load_state_dict_pre_hooks[handle.id] = hook
|
|
3184
|
+
return handle
|
|
3185
|
+
|
|
3186
|
+
@jit_forbidden_register
|
|
3187
|
+
def register_load_state_dict_post_hook(self, hook):
|
|
3188
|
+
r"""Register a post-hook to be run after cell's :func:`mindspore.nn.Cell.load_state_dict` is called.
|
|
3189
|
+
|
|
3190
|
+
It should have the following signature:
|
|
3191
|
+
|
|
3192
|
+
hook(cell, incompatible_keys) -> None
|
|
3193
|
+
|
|
3194
|
+
The ``cell`` argument is the current cell that this hook is registered
|
|
3195
|
+
on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
|
|
3196
|
+
of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
|
|
3197
|
+
is a ``list`` of ``str`` containing the missing keys and
|
|
3198
|
+
``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
|
|
3199
|
+
|
|
3200
|
+
The given incompatible_keys can be modified inplace if needed.
|
|
3201
|
+
|
|
3202
|
+
Note that the checks performed when calling :func:`load_state_dict` with
|
|
3203
|
+
``strict=True`` are affected by modifications the hook makes to
|
|
3204
|
+
``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
|
|
3205
|
+
set of keys will result in an error being thrown when ``strict=True``, and
|
|
3206
|
+
clearing out both missing and unexpected keys will avoid an error.
|
|
3207
|
+
|
|
3208
|
+
Args:
|
|
3209
|
+
hook (Callable): The hook function after `load_state_dict` is called.
|
|
3210
|
+
|
|
3211
|
+
Returns:
|
|
3212
|
+
A handle that can be used to remove the added hook by calling
|
|
3213
|
+
`handle.remove()`.
|
|
3214
|
+
"""
|
|
3215
|
+
from mindspore.utils.hooks import _RemovableHandle
|
|
3216
|
+
handle = _RemovableHandle(self._load_state_dict_post_hooks)
|
|
3217
|
+
self._load_state_dict_post_hooks[handle.id] = hook
|
|
3218
|
+
return handle
|
|
3219
|
+
|
|
3220
|
+
def _load_from_state_dict(
|
|
3221
|
+
self,
|
|
3222
|
+
state_dict,
|
|
3223
|
+
prefix,
|
|
3224
|
+
local_metadata,
|
|
3225
|
+
strict,
|
|
3226
|
+
missing_keys,
|
|
3227
|
+
unexpected_keys,
|
|
3228
|
+
error_msgs,
|
|
3229
|
+
):
|
|
3230
|
+
r"""Copy parameters and buffers from :attr:`state_dict` into only this cell, but not its descendants.
|
|
3231
|
+
|
|
3232
|
+
This is called on every sub cell
|
|
3233
|
+
in :func:`mindspore.nn.Cell.load_state_dict`. Metadata saved for this
|
|
3234
|
+
cell in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
|
3235
|
+
For state dicts without metadata, :attr:`local_metadata` is empty.
|
|
3236
|
+
Subclasses can achieve class-specific backward compatible loading using
|
|
3237
|
+
the version number at `local_metadata.get("version", None)`.
|
|
3238
|
+
|
|
3239
|
+
.. note::
|
|
3240
|
+
:attr:`state_dict` is not the same object as the input
|
|
3241
|
+
:attr:`state_dict` to :func:`mindspore.nn.Cell.load_state_dict`. So
|
|
3242
|
+
it can be modified.
|
|
3243
|
+
|
|
3244
|
+
Args:
|
|
3245
|
+
state_dict (dict): a dict containing parameters and
|
|
3246
|
+
persistent buffers.
|
|
3247
|
+
prefix (str): the prefix for parameters and buffers used in this
|
|
3248
|
+
cell
|
|
3249
|
+
local_metadata (dict): a dict containing the metadata for this cell.
|
|
3250
|
+
See
|
|
3251
|
+
strict (bool): whether to strictly enforce that the keys in
|
|
3252
|
+
:attr:`state_dict` with :attr:`prefix` match the names of
|
|
3253
|
+
parameters and buffers in this cell
|
|
3254
|
+
missing_keys (list of str): if ``strict=True``, add missing keys to
|
|
3255
|
+
this list
|
|
3256
|
+
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
|
3257
|
+
keys to this list
|
|
3258
|
+
error_msgs (list of str): error messages should be added to this
|
|
3259
|
+
list, and will be reported together in
|
|
3260
|
+
:func:`mindspore.nn.Cell.load_state_dict`
|
|
3261
|
+
"""
|
|
3262
|
+
for hook in self._load_state_dict_pre_hooks.values():
|
|
3263
|
+
hook(
|
|
3264
|
+
self,
|
|
3265
|
+
state_dict,
|
|
3266
|
+
prefix,
|
|
3267
|
+
local_metadata,
|
|
3268
|
+
strict,
|
|
3269
|
+
missing_keys,
|
|
3270
|
+
unexpected_keys,
|
|
3271
|
+
error_msgs,
|
|
3272
|
+
)
|
|
3273
|
+
|
|
3274
|
+
persistent_buffers = {
|
|
3275
|
+
k: v
|
|
3276
|
+
for k, v in self._buffers.items()
|
|
3277
|
+
if k not in self._non_persistent_buffers_set
|
|
3278
|
+
}
|
|
3279
|
+
local_name_params = itertools.chain(
|
|
3280
|
+
self._params.items(), persistent_buffers.items()
|
|
3281
|
+
)
|
|
3282
|
+
local_state = {k: v for k, v in local_name_params if v is not None}
|
|
3283
|
+
|
|
3284
|
+
for name, param in local_state.items():
|
|
3285
|
+
key = prefix + name
|
|
3286
|
+
if key in state_dict:
|
|
3287
|
+
input_param = state_dict[key]
|
|
3288
|
+
if not isinstance(input_param, Tensor):
|
|
3289
|
+
error_msgs.append(
|
|
3290
|
+
f'While copying the parameter named "{key}", '
|
|
3291
|
+
"expected Tensor or Tensor-like object from checkpoint but "
|
|
3292
|
+
f"received {type(input_param)}"
|
|
3293
|
+
)
|
|
3294
|
+
continue
|
|
3295
|
+
|
|
3296
|
+
if input_param.shape != param.shape:
|
|
3297
|
+
# local shape should match the one in checkpoint
|
|
3298
|
+
error_msgs.append(
|
|
3299
|
+
f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, "
|
|
3300
|
+
f"the shape in current model is {param.shape}."
|
|
3301
|
+
)
|
|
3302
|
+
continue
|
|
3303
|
+
try:
|
|
3304
|
+
param.assign_value(Tensor(input_param.asnumpy(), dtype=param.dtype))
|
|
3305
|
+
except Exception as ex: # pylint: disable=W0703
|
|
3306
|
+
error_msgs.append(
|
|
3307
|
+
f'While copy the parameter named "{key}", '
|
|
3308
|
+
f"whose shape in the model are {param.shape} and "
|
|
3309
|
+
f"whose shape in the checkpoint are {input_param.shape}, "
|
|
3310
|
+
f"an exception occurred : {ex.args}."
|
|
3311
|
+
)
|
|
3312
|
+
elif strict:
|
|
3313
|
+
missing_keys.append(key)
|
|
3314
|
+
|
|
3315
|
+
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
|
3316
|
+
if getattr(self.__class__, "set_extra_state", Cell.set_extra_state) is not Cell.set_extra_state:
|
|
3317
|
+
if extra_state_key in state_dict:
|
|
3318
|
+
self.set_extra_state(state_dict[extra_state_key])
|
|
3319
|
+
elif strict:
|
|
3320
|
+
missing_keys.append(extra_state_key)
|
|
3321
|
+
elif strict and (extra_state_key in state_dict):
|
|
3322
|
+
unexpected_keys.append(extra_state_key)
|
|
3323
|
+
|
|
3324
|
+
if strict:
|
|
3325
|
+
for key in state_dict.keys():
|
|
3326
|
+
if key.startswith(prefix) and key != extra_state_key:
|
|
3327
|
+
input_name = key[len(prefix):].split(".", 1)
|
|
3328
|
+
# Must be cell if it have attributes
|
|
3329
|
+
if len(input_name) > 1:
|
|
3330
|
+
if input_name[0] not in self._cells:
|
|
3331
|
+
unexpected_keys.append(key)
|
|
3332
|
+
elif input_name[0] not in local_state:
|
|
3333
|
+
unexpected_keys.append(key)
|
|
3334
|
+
|
|
3335
|
+
@jit_forbidden_register
|
|
3336
|
+
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
|
3337
|
+
r"""Copy parameters and buffers from :attr:`state_dict` into this cell and its descendants.
|
|
3338
|
+
|
|
3339
|
+
If :attr:`strict` is ``True``, then
|
|
3340
|
+
the keys of :attr:`state_dict` must exactly match the keys returned
|
|
3341
|
+
by this cell's :func:`mindspore.nn.Cell.state_dict` function.
|
|
3342
|
+
|
|
3343
|
+
Args:
|
|
3344
|
+
state_dict (dict): A dict containing parameters and
|
|
3345
|
+
persistent buffers.
|
|
3346
|
+
strict (bool, optional): Whether to strictly enforce that the keys
|
|
3347
|
+
in input `state_dict` match the keys returned by this cell's
|
|
3348
|
+
:func:`mindspore.nn.Cell.state_dict` function. Default ``True`` .
|
|
3349
|
+
|
|
3350
|
+
Returns:
|
|
3351
|
+
A namedtuple with ``missing_keys`` and ``unexpected_keys`` fields,
|
|
3352
|
+
|
|
3353
|
+
- `missing_keys` is a list of str containing any keys that are expected
|
|
3354
|
+
by this cell but missing from the provided ``state_dict``.
|
|
3355
|
+
|
|
3356
|
+
- `unexpected_keys` is a list of str containing the keys that are not
|
|
3357
|
+
expected by this cell but present in the provided ``state_dict``.
|
|
3358
|
+
|
|
3359
|
+
Note:
|
|
3360
|
+
If `strict` is ``True`` and a parameter or buffer is registered as ``None``, but its corresponding key
|
|
3361
|
+
exists in :attr:`state_dict`, and :func:`mindspore.nn.Cell.load_state_dict` will raise a ``RuntimeError``.
|
|
3362
|
+
|
|
3363
|
+
Examples:
|
|
3364
|
+
>>> import mindspore
|
|
3365
|
+
>>> import os
|
|
3366
|
+
>>> class Model(mindspore.nn.Cell):
|
|
3367
|
+
... def __init__(self):
|
|
3368
|
+
... super().__init__()
|
|
3369
|
+
... self.register_buffer("buffer_a", mindspore.tensor([4, 5, 6]))
|
|
3370
|
+
... self.param_a = mindspore.Parameter(mindspore.tensor([1, 2, 3]))
|
|
3371
|
+
...
|
|
3372
|
+
... def construct(self, x):
|
|
3373
|
+
... return x + self.buffer_a + self.param_a
|
|
3374
|
+
...
|
|
3375
|
+
...
|
|
3376
|
+
>>> model = Model()
|
|
3377
|
+
>>> print(model.state_dict())
|
|
3378
|
+
>>> mindspore.save_checkpoint(model.state_dict(), './model_state_dict_ckpt')
|
|
3379
|
+
>>> new_model = Model()
|
|
3380
|
+
>>> new_model.load_state_dict(mindspore.load_checkpoint('./model_state_dict_ckpt'))
|
|
3381
|
+
>>> print(new_model.state_dict())
|
|
3382
|
+
>>> os.remove('./model_state_dict_ckpt')
|
|
3383
|
+
OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
|
|
3384
|
+
('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
|
|
3385
|
+
OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
|
|
3386
|
+
('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
|
|
3387
|
+
"""
|
|
3388
|
+
if not isinstance(state_dict, Mapping):
|
|
3389
|
+
raise TypeError(
|
|
3390
|
+
f"Expected state_dict to be dict-like, got {type(state_dict)}."
|
|
3391
|
+
)
|
|
3392
|
+
|
|
3393
|
+
missing_keys: List[str] = []
|
|
3394
|
+
unexpected_keys: List[str] = []
|
|
3395
|
+
error_msgs: List[str] = []
|
|
3396
|
+
|
|
3397
|
+
# copy state_dict so _load_from_state_dict can modify it
|
|
3398
|
+
metadata = getattr(state_dict, "_metadata", None)
|
|
3399
|
+
state_dict = OrderedDict(state_dict)
|
|
3400
|
+
if metadata is not None:
|
|
3401
|
+
# mypy isn't aware that "_metadata" exists in state_dict
|
|
3402
|
+
state_dict._metadata = metadata # type: ignore[attr-defined]
|
|
3403
|
+
|
|
3404
|
+
def load(cell, local_state_dict, prefix=""):
|
|
3405
|
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
|
3406
|
+
cell._load_from_state_dict(
|
|
3407
|
+
local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
|
|
3408
|
+
)
|
|
3409
|
+
for name, child in cell._cells.items():
|
|
3410
|
+
if child is not None:
|
|
3411
|
+
child_prefix = prefix + name + "."
|
|
3412
|
+
child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
|
|
3413
|
+
load(child, child_state_dict, child_prefix) # noqa: F821
|
|
3414
|
+
|
|
3415
|
+
# Note that the hook can modify missing_keys and unexpected_keys.
|
|
3416
|
+
incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
|
|
3417
|
+
for hook in cell._load_state_dict_post_hooks.values():
|
|
3418
|
+
out = hook(cell, incompatible_keys)
|
|
3419
|
+
if out is not None:
|
|
3420
|
+
raise RuntimeError(
|
|
3421
|
+
"Hooks registered with ``register_load_state_dict_post_hook`` are not"
|
|
3422
|
+
"expected to return new values, if incompatible_keys need to be modified,"
|
|
3423
|
+
"it should be done inplace."
|
|
3424
|
+
)
|
|
3425
|
+
|
|
3426
|
+
load(self, state_dict)
|
|
3427
|
+
del load
|
|
3428
|
+
|
|
3429
|
+
if strict:
|
|
3430
|
+
if unexpected_keys:
|
|
3431
|
+
error_msgs.insert(
|
|
3432
|
+
0,
|
|
3433
|
+
"Unexpected key(s) in state_dict: {}. ".format(
|
|
3434
|
+
", ".join(f'"{k}"' for k in unexpected_keys)
|
|
3435
|
+
),
|
|
3436
|
+
)
|
|
3437
|
+
if missing_keys:
|
|
3438
|
+
error_msgs.insert(
|
|
3439
|
+
0,
|
|
3440
|
+
"Missing key(s) in state_dict: {}. ".format(
|
|
3441
|
+
", ".join(f'"{k}"' for k in missing_keys)
|
|
3442
|
+
),
|
|
3443
|
+
)
|
|
3444
|
+
|
|
3445
|
+
if error_msgs:
|
|
3446
|
+
raise RuntimeError(
|
|
3447
|
+
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
|
3448
|
+
self.__class__.__name__, "\n\t".join(error_msgs)
|
|
3449
|
+
)
|
|
3450
|
+
)
|
|
3451
|
+
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
|
3452
|
+
|
|
2364
3453
|
def register_backward_hook(self, hook_fn):
|
|
2365
3454
|
"""
|
|
2366
3455
|
Register the backward hook function.
|
|
@@ -2420,8 +3509,7 @@ class Cell(Cell_):
|
|
|
2420
3509
|
"""
|
|
2421
3510
|
if context._get_mode() == context.GRAPH_MODE:
|
|
2422
3511
|
return HookHandle()
|
|
2423
|
-
|
|
2424
|
-
return HookHandle()
|
|
3512
|
+
check_hook_fn(hook_fn)
|
|
2425
3513
|
handle = HookHandle(self._backward_hook)
|
|
2426
3514
|
self._backward_hook[handle.handle_id] = hook_fn
|
|
2427
3515
|
if self._cell_backward_hook is None:
|
|
@@ -2565,8 +3653,9 @@ class Cell(Cell_):
|
|
|
2565
3653
|
if not self._has_config_recompute:
|
|
2566
3654
|
self._has_config_recompute = True
|
|
2567
3655
|
else:
|
|
2568
|
-
|
|
2569
|
-
|
|
3656
|
+
logger.info("The recompute interface can be configured only once."
|
|
3657
|
+
" When the parent cell is configured, the child cell should not be configured")
|
|
3658
|
+
return
|
|
2570
3659
|
self._set_recompute_scope(mode)
|
|
2571
3660
|
if mode and not output_recompute:
|
|
2572
3661
|
self.add_flags(output_no_recompute=True)
|
|
@@ -2606,18 +3695,13 @@ class Cell(Cell_):
|
|
|
2606
3695
|
"""
|
|
2607
3696
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
2608
3697
|
self._recompute_cell = recompute_registry.get()(self.construct)
|
|
2609
|
-
self._add_recompute_flag()
|
|
2610
|
-
return
|
|
2611
3698
|
self._recompute()
|
|
2612
3699
|
if 'mp_comm_recompute' in kwargs.keys():
|
|
2613
3700
|
self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
|
|
2614
3701
|
if 'parallel_optimizer_comm_recompute' in kwargs.keys():
|
|
2615
|
-
if
|
|
2616
|
-
context.get_auto_parallel_context("pipeline_stages") > 1):
|
|
3702
|
+
if kwargs.get('parallel_optimizer_comm_recompute', False):
|
|
2617
3703
|
logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
|
|
2618
|
-
"
|
|
2619
|
-
elif context.get_auto_parallel_context("pipeline_stages") == 1:
|
|
2620
|
-
self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
|
|
3704
|
+
"is replaced with zero3.")
|
|
2621
3705
|
if 'recompute_slice_activation' in kwargs:
|
|
2622
3706
|
self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
|
|
2623
3707
|
|
|
@@ -2709,18 +3793,6 @@ class Cell(Cell_):
|
|
|
2709
3793
|
if hasattr(network, "_amp_level"):
|
|
2710
3794
|
self._amp_level = getattr(network, "_amp_level")
|
|
2711
3795
|
|
|
2712
|
-
def _add_recompute_flag(self):
|
|
2713
|
-
"""
|
|
2714
|
-
Set pynative cell recomputed.
|
|
2715
|
-
"""
|
|
2716
|
-
if not self._has_config_recompute:
|
|
2717
|
-
self._has_config_recompute = True
|
|
2718
|
-
else:
|
|
2719
|
-
logger.info("The recompute interface can be configured only once."
|
|
2720
|
-
" If the parent cell is configured, the child cell should not be configured")
|
|
2721
|
-
for cell in self.cells():
|
|
2722
|
-
cell._add_recompute_flag()
|
|
2723
|
-
|
|
2724
3796
|
def _register_parameters_hook(self, forward_hook=None, backward_hook=None, all=False):
|
|
2725
3797
|
"""
|
|
2726
3798
|
Register the forward hook for parameters and register the backward hook for the corresponding gradient.
|
|
@@ -2807,6 +3879,7 @@ class Cell(Cell_):
|
|
|
2807
3879
|
cell._parameters_forward_hook = forward_hook
|
|
2808
3880
|
cell._parameters_backward_hook = backward_hook
|
|
2809
3881
|
|
|
3882
|
+
|
|
2810
3883
|
class GraphCell(Cell):
|
|
2811
3884
|
"""
|
|
2812
3885
|
Base class for running the graph loaded from a MindIR.
|
|
@@ -2820,12 +3893,10 @@ class GraphCell(Cell):
|
|
|
2820
3893
|
The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
|
|
2821
3894
|
If the parameter exists in the graph according to the name, update it's value.
|
|
2822
3895
|
If the parameter does not exist, ignore it. Default: ``None`` .
|
|
2823
|
-
obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation
|
|
2824
|
-
used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
|
|
2825
|
-
a func_graph loaded from a mindir file obfuscated with `obf_random_seed` , then `obf_random_seed` should be
|
|
2826
|
-
provided. `obf_random_seed` should be in (0, 9223372036854775807]. default: ``None`` .
|
|
3896
|
+
obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation, which is not supported now.
|
|
2827
3897
|
|
|
2828
3898
|
Raises:
|
|
3899
|
+
NotImplementedError: Dynamic structure obfuscation is not supported now.
|
|
2829
3900
|
TypeError: If the `graph` is not a FuncGraph.
|
|
2830
3901
|
TypeError: If the `params_init` is not a dict.
|
|
2831
3902
|
TypeError: If the key of the `params_init` is not a str.
|
|
@@ -2855,20 +3926,12 @@ class GraphCell(Cell):
|
|
|
2855
3926
|
|
|
2856
3927
|
def __init__(self, graph, params_init=None, obf_random_seed=None):
|
|
2857
3928
|
super(GraphCell, self).__init__(auto_prefix=True)
|
|
3929
|
+
if obf_random_seed is not None:
|
|
3930
|
+
raise NotImplementedError("Dynamic structure obfuscation is not supported now.")
|
|
2858
3931
|
if not isinstance(graph, FuncGraph):
|
|
2859
3932
|
raise TypeError(f"For 'GraphCell', the argument 'graph' must be a FuncGraph loaded from MindIR, "
|
|
2860
3933
|
f"but got type {type(graph)}.")
|
|
2861
3934
|
self.graph = graph
|
|
2862
|
-
self.obf_random_seed = obf_random_seed
|
|
2863
|
-
if obf_random_seed is not None:
|
|
2864
|
-
if not isinstance(obf_random_seed, int):
|
|
2865
|
-
raise TypeError("'obf_random_seed' must be int, but got {}.".format(type(obf_random_seed)))
|
|
2866
|
-
int_64_max = 9223372036854775807
|
|
2867
|
-
if obf_random_seed <= 0 or obf_random_seed > int_64_max:
|
|
2868
|
-
raise ValueError(
|
|
2869
|
-
"'obf_random_seed' must be larger than 0, and less or equal than int64 ({}),"
|
|
2870
|
-
"but got {}.".format(int_64_max, obf_random_seed))
|
|
2871
|
-
self._branch_control_input = _generate_branch_control_input(self.obf_random_seed)
|
|
2872
3935
|
params_init = {} if params_init is None else params_init
|
|
2873
3936
|
if not isinstance(params_init, dict):
|
|
2874
3937
|
raise TypeError(f"For 'GraphCell', the argument 'params_init' must be a dict, but got {type(params_init)}.")
|
|
@@ -2888,19 +3951,30 @@ class GraphCell(Cell):
|
|
|
2888
3951
|
def __call__(self, *args, **kwargs):
|
|
2889
3952
|
self.phase = "graph_load_from_mindir"
|
|
2890
3953
|
self._add_attr("graph_load_from_mindir", self.graph)
|
|
2891
|
-
|
|
2892
|
-
return self.compile_and_run(*args, **kwargs)
|
|
2893
|
-
append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32))
|
|
2894
|
-
return self.compile_and_run(*args, append_input, **kwargs)
|
|
3954
|
+
return self.compile_and_run(*args, **kwargs)
|
|
2895
3955
|
|
|
2896
3956
|
|
|
2897
|
-
def
|
|
3957
|
+
def _is_parameter_list_or_tuple(value):
|
|
2898
3958
|
"""
|
|
2899
3959
|
Check the type of input in list or tuple is Parameter.
|
|
2900
3960
|
:param value: list or tuple.
|
|
2901
3961
|
:return: The types of all inputs are parameter.
|
|
2902
3962
|
"""
|
|
2903
|
-
|
|
2904
|
-
|
|
2905
|
-
|
|
2906
|
-
|
|
3963
|
+
if isinstance(value, (list, tuple)) and value:
|
|
3964
|
+
for item in value:
|
|
3965
|
+
if not isinstance(item, Parameter):
|
|
3966
|
+
return False
|
|
3967
|
+
return True
|
|
3968
|
+
return False
|
|
3969
|
+
|
|
3970
|
+
|
|
3971
|
+
def _addindent(s_, num_spaces):
|
|
3972
|
+
s = s_.split("\n")
|
|
3973
|
+
# don't do anything for single-line stuff
|
|
3974
|
+
if len(s) == 1:
|
|
3975
|
+
return s_
|
|
3976
|
+
first = s.pop(0)
|
|
3977
|
+
s = [(num_spaces * " ") + line for line in s]
|
|
3978
|
+
s = "\n".join(s)
|
|
3979
|
+
s = first + "\n" + s
|
|
3980
|
+
return s
|