mindspore-2.2.14-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registry.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +124 -25
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +182 -68
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +192 -252
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +67 -26
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +20 -7
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +10 -10
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +449 -129
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +8 -11
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +65 -30
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +73 -21
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +276 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +668 -514
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +43 -3
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +117 -104
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +455 -71
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +201 -116
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +83 -3
- mindspore/dataset/engine/datasets_text.py +39 -39
- mindspore/dataset/engine/datasets_user_defined.py +230 -141
- mindspore/dataset/engine/datasets_vision.py +78 -74
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +138 -66
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +41 -15
- mindspore/dataset/text/__init__.py +2 -5
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +7 -10
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +16 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +124 -15
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +18 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/status.h +14 -0
- mindspore/include/api/types.h +10 -10
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +3 -5
- mindspore/include/dataset/vision.h +58 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +3 -3
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +138 -103
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/cell.py +461 -323
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +292 -135
- mindspore/nn/layer/basic.py +288 -83
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +221 -45
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +150 -68
- mindspore/nn/layer/padding.py +64 -87
- mindspore/nn/layer/pooling.py +175 -12
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +55 -53
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +145 -88
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +4 -2
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +46 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +44 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +67 -68
- mindspore/numpy/array_ops.py +70 -66
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +147 -152
- mindspore/numpy/utils.py +3 -0
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +9 -6
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
- mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +162 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +147 -59
- mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +201 -66
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/function/__init__.py +53 -11
- mindspore/ops/function/array_func.py +1269 -1821
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +114 -5
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/grad_func.py +30 -22
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +1170 -2697
- mindspore/ops/function/nn_func.py +2116 -1128
- mindspore/ops/function/other_func.py +8 -8
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +435 -113
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +16 -15
- mindspore/ops/functional.py +355 -346
- mindspore/ops/op_info_register.py +18 -45
- mindspore/ops/operations/__init__.py +38 -24
- mindspore/ops/operations/_grad_ops.py +21 -927
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +173 -607
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +106 -2837
- mindspore/ops/operations/comm_ops.py +799 -127
- mindspore/ops/operations/custom_ops.py +124 -119
- mindspore/ops/operations/debug_ops.py +142 -41
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +666 -4972
- mindspore/ops/operations/nn_ops.py +205 -2213
- mindspore/ops/operations/other_ops.py +60 -49
- mindspore/ops/operations/random_ops.py +50 -54
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +216 -103
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +8 -4
- mindspore/parallel/_auto_parallel_context.py +100 -10
- mindspore/parallel/_cell_wrapper.py +99 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +67 -23
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +99 -22
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +173 -6
- mindspore/parallel/algo_parameter_config.py +8 -10
- mindspore/parallel/checkpoint_transform.py +204 -38
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +279 -37
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +4 -2
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +153 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +18 -20
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
- mindspore/profiler/parser/base_timeline_generator.py +25 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +4 -393
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/profiler_info.py +94 -7
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +631 -508
- mindspore/rewrite/__init__.py +2 -14
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +40 -115
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +597 -263
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +7 -5
- mindspore/train/_utils.py +204 -4
- mindspore/train/amp.py +335 -295
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +220 -43
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +5 -5
- mindspore/train/callback/_on_request_exit.py +136 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +12 -12
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -23
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +382 -76
- mindspore/train/serialization.py +787 -288
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -282
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2020-
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
@@ -20,10 +20,9 @@ import inspect
-from types import FunctionType, MethodType
-from mindspore._checkparam import args_type_check
+from mindspore._checkparam import args_type_check, check_hook_fn
@@ -33,8 +32,9 @@ from mindspore import context
-from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
-from mindspore.common.api import _generate_branch_control_input
+from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, _no_grad
+from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
+from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
@@ -43,16 +43,7 @@ from mindspore.ops.operations import _inner_ops as inner
-from mindspore.
-from mindspore.ops._tracefunc import _convert_tensor, _SetMixedPrecision, PackFunc
-
-
-def _check_args(args):
-    """Check the input args's type"""
-    for item in args:
-        if isinstance(item, Tensor) and item.has_init:
-            item.init_data()
-
+from mindspore.common._register_for_recompute import recompute_registry
@@ -89,7 +80,7 @@ class Cell(Cell_):
-        >>>
+        >>> from mindspore import ops
@@ -109,17 +100,19 @@ class Cell(Cell_):
-                   '_func_graph_flags', '_parameter_layout_dict', '_params_list', '
-                   '_forward_pre_hook', '_forward_hook', '
-                   '
+                   '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
+                   '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
+                   '_cell_backward_pre_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
                    '_attr_synced', 'pynative', 'requires_grad', 'cell_type']
+    total_instance_count = 0

     def __init__(self, auto_prefix=True, flags=None):
         Cell_.__init__(self, self._cell_tag)
+        Cell.total_instance_count += 1
+        self.instance_count = Cell.total_instance_count
-        self._tensor_list = OrderedDict()
@@ -135,11 +128,15 @@ class Cell(Cell_):
         self.compile_cache = set()
+        self.phase_cache = dict()
         self.exist_objs = set()
+        self._recompute_cell = None
+        self.mixed_precision_type = None
+        self.sig = inspect.signature(self.construct)
@@ -149,24 +146,33 @@ class Cell(Cell_):
         self._bprop_debug = False
+
+        # hook
         self._forward_pre_hook = OrderedDict()
         self._forward_hook = OrderedDict()
-        self.
-        self.
-        self.
+        self._backward_pre_hook = OrderedDict()
+        self._cell_backward_pre_hook = None
+        self._backward_hook = OrderedDict()
         self._cell_backward_hook = None
         self._is_recursion_hook = False
+
         self._dynamic_shape_inputs = None
+        self._compile_args = None
         self._amp_level = ""
+        self._init_flag = False
+        self._shard_fn = None
+        self.has_bprop = False
+        if hasattr(self, "bprop"):
+            self.has_bprop = True
@@ -224,8 +230,9 @@ class Cell(Cell_):
         Tutorial Examples:
-            - `
-              <https://mindspore.cn/
+            - `Custom Neural Network Layers - Custom Cell Reverse
+              <https://mindspore.cn/docs/en/master/model_train/custom_program/network_custom.html
+              #custom-cell-reverse>`_
@@ -317,10 +324,23 @@ class Cell(Cell_):
     def pipeline_stage(self):
+        """
+        `pipeline_stage` represents the pipeline stage of current Cell.
+        """
         return self._pipeline_stage

     @pipeline_stage.setter
     def pipeline_stage(self, value):
+        """
+        Set the `pipeline_stage` of a Cell.
+
+        Args:
+            value (int): The pipeline stage of a parameter.
+
+        Raises:
+            TypeError: If `value` is not int type or is a bool type.
+            ValueError: If `value` is not a positive integer.
+        """
@@ -362,6 +382,10 @@ class Cell(Cell_):
+    @property
+    def enable_backward_hook(self):
+        return self._enable_backward_hook
+
     def get_func_graph_proto(self):
@@ -376,10 +400,6 @@ class Cell(Cell_):
-        if '_tensor_list' in self.__dict__:
-            tensor_list = self.__dict__['_tensor_list']
-            if name in tensor_list:
-                return tensor_list[name]
@@ -391,12 +411,9 @@ class Cell(Cell_):
         cells_compile_cache.pop(id(self), None)
-        except AttributeError as e:
-            raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
-                                 f"Please use 'super().__init__()'.") from e
+        if hasattr(self, "compile_cache") and self.compile_cache:
+            _cell_graph_executor.del_net_res(self, self.compile_cache)
+        Cell.total_instance_count -= 1
@@ -405,8 +422,6 @@ class Cell(Cell_):
-        elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
-            del self._tensor_list[name]
@@ -420,7 +435,7 @@ class Cell(Cell_):
         elif hasattr(item, "dtype") and item.dtype in \
-
+                {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
@@ -470,18 +485,28 @@ class Cell(Cell_):
-    def _run_construct(self,
+    def _run_construct(self, *inputs, **kwargs):
         """Run the construct function"""
-        if self.
+        if self._forward_pre_hook:
+            inputs = self._run_forward_pre_hook(inputs)
+
+        if self._backward_hook:
+            output = self._backward_hook_construct(*inputs, **kwargs)
+        elif self._shard_fn is not None:
+            output = self._shard_fn(*inputs, **kwargs)
+        elif self._recompute_cell is not None:
+            output = self._recompute_cell(*inputs, **kwargs)
+        elif self.has_bprop and _pynative_executor.requires_grad():
+            output = self._call_custom_bprop(*inputs, **kwargs)
         else:
-            output = self.construct(*
+            output = self.construct(*inputs, **kwargs)
+
+        if self._forward_hook:
+            output = self._run_forward_hook(inputs, output)
+
+        if self._backward_pre_hook:
+            output = self._run_backward_pre_hook(output)
+
         return output
@@ -519,7 +544,7 @@ class Cell(Cell_):
         try:
-            if self.
+            if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
                 return True
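The reworked `_run_construct` dispatches in a fixed order: forward pre hooks may rewrite the inputs, then exactly one of the backward-hook wrapper, the stored shard function, the recompute cell, the custom `bprop` path, or plain `construct` produces the output, and forward hooks and backward pre hooks post-process it. A minimal usage sketch of the custom `bprop` branch, assuming PyNative mode; the `MulCell` class below is illustrative and not part of the diff:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, ops
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> class MulCell(nn.Cell):
...     def construct(self, x):
...         return 3 * x
...     def bprop(self, x, out, dout):
...         # custom gradient: ignore the analytic derivative of 3 * x
...         return (5 * dout,)
...
>>> grad_fn = ops.GradOperation(get_all=True)(MulCell())
>>> grads = grad_fn(Tensor(np.ones([1]).astype(np.float32)))

With the custom `bprop` the returned gradient is 5 rather than the analytic 3, which is the path `_call_custom_bprop` handles above.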
@@ -566,22 +591,22 @@ class Cell(Cell_):
-        generated by sharding propagation. In PyNative mode, use this method
-        to specify
+        generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
+        execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
+        strategy for others will be set by sharding propagation.
-        this input/output,
-        which can refer to the description of `mindspore.ops.Primitive.shard`.
+        this input/output, which can refer to the description of `mindspore.ops.Primitive.shard`.
         Note:
+            If Cell.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
+            "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
             If the input contain Parameter, its strategy should be set in `in_strategy`.
         Args:
-            in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple
+            in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
+                defines the layout of the corresponding input.
@@ -598,7 +623,7 @@ class Cell(Cell_):
         Returns:
-
+            Function, return the cell construct function that will be executed under auto parallel process.
@@ -616,22 +641,21 @@ class Cell(Cell_):
-            ...         self.block2.shard(in_strategy=((2, 1),),
-            ...
+            ...         self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
+            ...                                               parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
-            ...         x = self.
+            ...         x = self.block2_shard(x)
         """
-        if context.
-                                 f"Please check if you call Cell.shard in the script.")
+        if context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel", "semi_auto_parallel"]:
+            raise AssertionError(f"Cell shard only supports auto parallel or semi_auto_parallel "
+                                 f"Please check the parallel mode in parallel context.")
         shard_fn = Shard()
         fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
-        return
+        self._shard_fn = fn
+        return fn
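With this change, `shard` raises an AssertionError unless the auto-parallel context is already in "auto_parallel" or "semi_auto_parallel" mode, so the parallel context has to be configured before the cell is sharded, and the returned callable (also stored in `self._shard_fn`) is what runs under the strategy. A minimal sketch, assuming a distributed job has already been launched; the `Dense` layer and strategy values are illustrative:

>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore.communication import init
>>> init()  # assumes the distributed environment is available
>>> ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
...                              search_mode="sharding_propagation")
>>> block = nn.Dense(8, 8)
>>> block_shard = block.shard(in_strategy=((2, 1),))  # keep and call the returned function

Calling the original cell afterwards also routes through the stored shard function in PyNative mode, per the `_run_construct` dispatch above.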
@@ -654,65 +678,113 @@ class Cell(Cell_):
-    def
-        bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
-        bound_arguments.apply_defaults()
-        args = bound_arguments.args
-        kwargs = bound_arguments.kwargs
+    def _init_check(self):
+        for param in self.get_parameters(expand=False):
+            if param.has_init:
+                param.init_data()
+        self._init_flag = True

-        if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
+    def _self_check(self):
+        if not self._is_check_and_refresh:
             self.check_names_and_refresh_name()
             self._is_check_and_refresh = True

+    def _predict(self, *args, **kwargs):
+        if not hasattr(self, "phase"):
+            return False, None
+        if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
+            new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
+            res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
+            res = _convert_python_data(res)
+            return True, res
+        return False, None
+
+    def __call__(self, *args, **kwargs):
         # Run in Graph mode.
-        if os.getenv("MS_JIT") != '0'
+        if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
+            if kwargs:
+                bound_arguments = self.sig.bind(*args, **kwargs)
+                bound_arguments.apply_defaults()
+                args = bound_arguments.args
+                kwargs = bound_arguments.kwargs
+
+            predict_compiled, res = self._predict(*args, **kwargs)
+            if predict_compiled:
+                return res
             self._check_construct_args(*args)
             if self._hook_fn_registered():
                 logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
                                f"function, please use context.set_context to set pynative mode.")
+            self._self_check()
             out = self.compile_and_run(*args, **kwargs)
             return out

         # Run in PyNative mode.
-        if
-            # There many Casts in parameter_broadcast. Enable build faster.
-            self._do_parameter_broadcast()
+        if not (self._init_flag or self._is_check_and_refresh):
+            self._init_check()
+            self._self_check()
+
+        if not (self.requires_grad or self._dynamic_shape_inputs or self.mixed_precision_type):
+            if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+                    self._shard_fn or self._recompute_cell or (self.has_bprop and _pynative_executor.requires_grad())):
+                return self.construct(*args, **kwargs)
-            _pynative_executor.set_grad_flag(True)
+            return self._run_construct(*args, **kwargs)
-            self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
+        return self._complex_call(*args, **kwargs)

+    def _complex_call(self, *args, **kwargs):
+        """
+        PyNative call with requires_grad or hooks
+        """
+        self._call_pre_process(*args, **kwargs)
+
+        if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+                self._shard_fn or self._recompute_cell or self.has_bprop):
+            output = self.construct(*args, **kwargs)
+        else:
+            output = self._run_construct(*args, **kwargs)
+
+        self._call_post_process(output, *args, **kwargs)
+
+        return output
+
+    def _call_pre_process(self, *args, **kwargs):
+        """
+        Process cell info before call construct
+        """
+        if self.requires_grad:
+            _pynative_executor.set_grad_flag(True)
             _pynative_executor.new_graph(self, *args, **kwargs)
+        elif self._dynamic_shape_inputs is not None:
+            _pynative_executor.set_cell_use_dynamic_shape_process(True)
+
+        # Set mixed precision
+        if self.mixed_precision_type is not None:
+            _pynative_executor.set_mixed_precision_type(self.mixed_precision_type)
+
+    def _call_post_process(self, output, *args, **kwargs):
+        """
+        Process cell info after call construct
+        """
+        if self.requires_grad:
             _pynative_executor.end_graph(self, output, *args, **kwargs)
-            _pynative_executor.
-            raise err
+        elif self._dynamic_shape_inputs is not None:
+            _pynative_executor.set_cell_use_dynamic_shape_process(False)

+        # mixed precision reset
+        if self.mixed_precision_type is not None:
+            _pynative_executor.set_mixed_precision_type(MixedPrecisionType.NOTSET, False)

-    def
+    def _call_custom_bprop(self, *args, **kwargs):
+        """
+        Call custom bprop for cell bprop.
+        """
+        with _no_grad():
+            output = self.construct(*args, **kwargs)
+        _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
+        return output
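The rewritten `__call__` gives PyNative execution a fast path: when a cell needs no gradient, dynamic-shape inputs or mixed precision, and has no hooks, shard function, recompute cell or custom bprop, the call goes straight to `construct`; otherwise it is routed through `_run_construct` or `_complex_call`. A small sketch of the two situations, assuming PyNative mode (the names below are illustrative):

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, ops
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> net = nn.ReLU()
>>> x = Tensor(np.ones([2, 3]).astype(np.float32))
>>> y = net(x)                       # no grad, hooks or shard: dispatched directly to construct
>>> g = ops.GradOperation()(net)(x)  # gradient bookkeeping path is taken instead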
@@ -830,15 +902,6 @@ class Cell(Cell_):
-    def _set_attr_for_tensor(self, name, value):
-        if context._get_mode() == context.PYNATIVE_MODE:
-            tensor_list = self.__dict__.get('_tensor_list')
-            if name in self.__dict__:
-                del self.__dict__[name]
-            tensor_list[name] = value
-        else:
-            object.__setattr__(self, name, value)
-
@@ -856,8 +919,6 @@ class Cell(Cell_):
-        elif isinstance(value, Tensor):
-            self._set_attr_for_tensor(name, value)
@@ -910,14 +971,25 @@ class Cell(Cell_):
-    def set_inputs(self, *inputs):
+    def set_inputs(self, *inputs, **kwargs):
-        configured with set_inputs. The
+        configured with set_inputs. The shape of input Tensor can be either dynamic or static.
+
+        .. note::
+            There are two mode:
+
+            - Full mode: arguments will be used as all compile inputs for graph-compiling.
+            - Incremental mode: arguments will set to some of the Cell inputs, which will be substituted into the input
+              at the corresponding position for graph-compiling.
+
+            Only one of inputs or kwargs can be set. Inputs for full mode and kwargs for incremental mode.
         Args:
-            inputs (tuple):
+            inputs (tuple): Full mode arguments.
+            kwargs (dict): Incremental mode arguments. The acceptable key is the name of parameter defined
+                in `self.construct`.
@@ -937,16 +1009,30 @@ class Cell(Cell_):
             >>> net.set_inputs(input_dyn)
-            >>>
-            >>> output = net(
+            >>> input = Tensor(np.random.random([3, 10]), dtype=ms.float32)
+            >>> output = net(input)
+            >>>
+            >>> net2 = ReluNet()
+            >>> net2.set_inputs(x=input_dyn)
+            >>> output = net2(input)
         """
+        if kwargs and inputs:
+            raise ValueError('For Cell, set_inputs should only set inputs or kwargs(inputs: %s, kwargs: %s)!'
+                             % (inputs, kwargs))
+
+        if not kwargs:
+            self._dynamic_shape_inputs = inputs
+            if context._get_mode() == context.PYNATIVE_MODE:
+                _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
+            else:
+                self._check_construct_args(*inputs)
+            # TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
+            # which means that incremental mode is lacking dynamic input.
+        else:
+            self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)
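In incremental mode only the named arguments are replaced with dynamic placeholders, while the remaining `construct` parameters keep the shapes they are called with. A minimal sketch with a two-input cell in graph mode; the `AddNet` class is illustrative and not part of the diff:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, ops
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> class AddNet(nn.Cell):
...     def construct(self, x, y):
...         return ops.add(x, y)
...
>>> net = AddNet()
>>> x_dyn = Tensor(shape=[None, 16], dtype=ms.float32)
>>> net.set_inputs(x=x_dyn)   # incremental mode: only `x` gets a dynamic first axis
>>> x = Tensor(np.ones([4, 16]).astype(np.float32))
>>> y = Tensor(np.ones([4, 16]).astype(np.float32))
>>> out = net(x, y)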
@@ -981,6 +1067,48 @@ class Cell(Cell_):
+    def _check_parameter_consistency(self, set_inputs, net_inputs):
+        """Check consistency for parameter."""
+        for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
+            if isinstance(set_input, Tensor):
+                if not isinstance(net_input, Tensor):
+                    raise TypeError(
+                        f"For 'set_inputs' and tuple(list) in 'set_inputs',the type of {index + 1}th input must "
+                        f"be Tensor, but got {type(net_input)}.")
+                if isinstance(set_input, Parameter) != isinstance(net_input, Parameter):
+                    raise TypeError(
+                        f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
+                        f"as expected, but got expected: {type(set_input)} and input: {type(net_input)}.")
+            elif isinstance(set_input, (tuple, list)):
+                if not isinstance(net_input, (tuple, list)):
+                    raise TypeError(
+                        f"The {index + 1}th input type of 'set_inputs' or tuple(list) in "
+                        f"'set_inputs' must be tuple or list, but got {type(net_input)}.")
+                self._check_parameter_consistency(set_input, net_input)
+
+    def _get_compile_args(self, args):
+        """Get compile arguments."""
+        # this is used only for test
+        set_by_auto_dynamic = False
+        if is_auto_dynamic():
+            if self._dynamic_shape_inputs is None:
+                set_by_auto_dynamic = True
+            else:
+                if isinstance(self._dynamic_shape_inputs, (list, tuple)) and self._dynamic_shape_inputs[0] is None:
+                    set_by_auto_dynamic = True
+            if set_by_auto_dynamic:
+                self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
+
+        if self._dynamic_shape_inputs is not None:
+            logger.debug("Compiled Graph with dynamic shape")
+            compile_args = _generate_dyn_compile_args(args, self._dynamic_shape_inputs)
+            _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs")
+            self._check_parameter_consistency(compile_args, args)
+            Validator.check_symbolic_shape(compile_args, args)
+            self.saved_dynamic_shape = compile_args
+            return compile_args
+        return args
@@ -989,19 +1117,9 @@ class Cell(Cell_):
-        if self._dynamic_shape_inputs is None:
-            _cell_graph_executor.compile(self, phase=self.phase,
-                                         jit_config_dict=self._jit_config_dict, *args, **kwargs)
-        else:
-            self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
-            self.saved_dynamic_shape = self._dynamic_shape_inputs
-            _cell_graph_executor.compile(self, *self._dynamic_shape_inputs, phase=self.phase,
-                                         jit_config_dict=self._jit_config_dict, **kwargs)
-            logger.debug("Compiled Graph with dynamic shape")
+        self._compile_args = self._get_compile_args(args)
+        _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
+                                     jit_config_dict=self._jit_config_dict, **kwargs)
@@ -1019,7 +1137,7 @@ class Cell(Cell_):
-        new_args = _get_args_for_run(self, args, kwargs)
+        new_args = _get_args_for_run(self, args, kwargs, self._compile_args)
@@ -1033,6 +1151,7 @@ class Cell(Cell_):
         """Executes GE saving checkpoint graph operation."""
+        logger.warning("'exec_checkpoint_graph' function is deprecated.")
@@ -1070,14 +1189,14 @@ class Cell(Cell_):
-            raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
+            raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not be None.")
-            raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not contain
+            raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not contain'.' ")
-            raise AttributeError("For 'insert_param_to_cell', please call Cell.__init__() firstly.")
+            raise AttributeError(f"For 'insert_param_to_cell', please call Cell.__init__() firstly.")
-            raise KeyError("For 'insert_param_to_cell', the {} parameter already exists in the network.
-                           "insert another parameter with the same name."
+            raise KeyError(f"For 'insert_param_to_cell', the {param_name} parameter already exists in the network."
+                           f"Cannot insert another parameter with the same name.")
@@ -1139,11 +1258,11 @@ class Cell(Cell_):
-            raise KeyError("For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
-                           "can not contain '.'")
+            raise KeyError(f"For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
+                           "can not contain '.' ")
-            raise KeyError("For 'insert_child_to_cell', the {} child cell already exists in the network.
-                           "insert another child cell with the same name."
+            raise KeyError(f"For 'insert_child_to_cell', the {child_name} child cell already exists in the network."
+                           f"Cannot insert another child cell with the same name.")
@@ -1163,7 +1282,7 @@ class Cell(Cell_):
-
+        raise AttributeError("For 'Cell', the method 'construct' is not defined.")
@@ -1361,7 +1480,7 @@ class Cell(Cell_):
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/beginner/train.html#optimizer>`_
@@ -1472,7 +1591,7 @@ class Cell(Cell_):
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/beginner/model.html#model-parameters>`_
@@ -1630,10 +1749,13 @@ class Cell(Cell_):
         if "fp16" in flags and flags.get("fp16", False):
+            self.mixed_precision_type = MixedPrecisionType.FP16
         if "fp32" in flags and flags.get("fp32", False):
+            self.mixed_precision_type = MixedPrecisionType.FP32
         if "bf16" in flags and flags.get("bf16", False):
+            self.mixed_precision_type = MixedPrecisionType.BF16
@@ -1698,6 +1820,9 @@ class Cell(Cell_):
         self._func_graph_flags.update({**flags})
+        if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
+            raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
+                            "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
@@ -1808,7 +1933,7 @@ class Cell(Cell_):
-        `algorithm library <https://gitee.com/mindspore/mindspore/tree/
+        `algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_.
@@ -1865,7 +1990,7 @@ class Cell(Cell_):
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/beginner/train.html#training-and-evaluation>`_
@@ -1945,11 +2070,11 @@ class Cell(Cell_):
-              `
+              `cell` is the object of registered Cell. `inputs` is the forward
             - It should have the following signature:
-              hook_fn(
+              hook_fn(cell, inputs) -> new input objects or none.
@@ -1959,8 +2084,8 @@ class Cell(Cell_):
         Returns:
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
@@ -1973,7 +2098,7 @@ class Cell(Cell_):
-            >>> def forward_pre_hook_fn(
+            >>> def forward_pre_hook_fn(cell, inputs):
@@ -1995,24 +2120,12 @@ class Cell(Cell_):
-        if context.
-            logger.warning(f"'register_forward_pre_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if context._get_mode() == context.GRAPH_MODE:
             return HookHandle()
-        if hook_fn.__code__.co_name == "staging_specialize":
-            raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-        self._enable_forward_pre_hook = True
-        _pynative_executor.set_hook_changed(self)
-        if not hasattr(self, '_forward_pre_hook_key'):
-            self._forward_pre_hook_key = -1
-        self._forward_pre_hook_key += 1
-        self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn
-        handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook")
+        if not check_hook_fn("register_forward_pre_hook", hook_fn):
+            return HookHandle()
+        handle = HookHandle(self._forward_pre_hook)
+        self._forward_pre_hook[handle.handle_id] = hook_fn
         return handle
@@ -2028,15 +2141,23 @@ class Cell(Cell_):
+        forward_pre_hook_inputs = inputs
         for fn in self._forward_pre_hook.values():
-            ret = fn(
+            ret = fn(self, forward_pre_hook_inputs)
             if ret is not None:
                 if not isinstance(ret, tuple):
+                    forward_pre_hook_inputs = (ret,)
                 else:
+                    forward_pre_hook_inputs = ret
+
+        if isinstance(inputs, tuple):
+            if not isinstance(forward_pre_hook_inputs, tuple):
+                forward_pre_hook_inputs = (forward_pre_hook_inputs,)
+            if len(forward_pre_hook_inputs) != len(inputs):
+                raise TypeError(
+                    "The forward pre hook return value size is {} not equal to input size {}".format(
+                        len(forward_pre_hook_inputs), len(inputs)))
+        return forward_pre_hook_inputs
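The hook registry is now keyed by the `HookHandle` itself, and the hook callback receives the cell as its first argument. A short sketch of registering a forward pre hook that rescales the input and then detaching it again through the returned handle; the `scale_inputs` function is illustrative:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def scale_inputs(cell, inputs):
...     # returning new inputs replaces the originals for this forward pass
...     return (inputs[0] * 2,)
...
>>> net = nn.ReLU()
>>> handle = net.register_forward_pre_hook(scale_inputs)
>>> out = net(Tensor(np.ones([2]).astype(np.float32)))
>>> handle.remove()   # later calls run without the hook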
@@ -2045,11 +2166,11 @@ class Cell(Cell_):
-              `
+              `cell` is the object of registered Cell. `inputs` is the forward
             - It should have the following signature:
-              hook_fn(
+              hook_fn(cell, inputs, output) -> new output object or none.
@@ -2059,8 +2180,8 @@ class Cell(Cell_):
         Returns:
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
@@ -2073,7 +2194,7 @@ class Cell(Cell_):
-            >>> def forward_hook_fn(
+            >>> def forward_hook_fn(cell, inputs, output):
@@ -2097,24 +2218,12 @@ class Cell(Cell_):
-        if context.
-            logger.warning(f"'register_forward_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if context._get_mode() == context.GRAPH_MODE:
             return HookHandle()
-        if hook_fn.__code__.co_name == "staging_specialize":
-            raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-        self._enable_forward_hook = True
-        _pynative_executor.set_hook_changed(self)
-        if not hasattr(self, '_forward_hook_key'):
-            self._forward_hook_key = -1
-        self._forward_hook_key += 1
-        self._forward_hook[self._forward_hook_key] = hook_fn
-        handle = HookHandle(self, self._forward_hook_key, "_forward_hook")
+        if not check_hook_fn("register_forward_hook", hook_fn):
+            return HookHandle()
+        handle = HookHandle(self._forward_hook)
+        self._forward_hook[handle.handle_id] = hook_fn
         return handle
@@ -2131,12 +2240,110 @@ class Cell(Cell_):
+        forward_hook_output = output
         for fn in self._forward_hook.values():
-            ret = fn(
+            ret = fn(self, inputs, forward_hook_output)
             if ret is not None:
+                forward_hook_output = ret
+
+        if isinstance(output, tuple):
+            if not isinstance(forward_hook_output, tuple):
+                forward_hook_output = (forward_hook_output,)
+            if len(forward_hook_output) != len(output):
+                raise TypeError(
+                    "The forward hook return value size is {} not equal to output size {}".format(
+                        len(forward_hook_output), len(output)))
+        return forward_hook_output
+
+    def register_backward_pre_hook(self, hook_fn):
+        """
+        Register the backward pre hook function.
+
+        Note:
+            - The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
+            - The 'hook_fn' must be defined as the following code.
+              `cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
+            - The 'hook_fn' should have the following signature:
+              hook_fn(cell, grad_output) -> New grad_output gradient or None.
+            - The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
+              graph mode, it is not recommended to write it in the `construct` function of Cell object.
+            - In the pynative
+              mode, if the `register_backward_pre_hook` function is called in the `construct` function of the Cell
+              object, a hook function will be added at each run time of Cell object.
+
+        Args:
+            hook_fn (function): Python function. Backward pre hook function.
+
+        Returns:
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
+
+        Raises:
+            TypeError: If the `hook_fn` is not a function of python.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+
+        Examples:
+            >>> import numpy as np
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn, ops
+            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+            >>> def backward_pre_hook_fn(cell, grad_output):
+            ...     print("backward input: ", grad_output)
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.relu = nn.ReLU()
+            ...         self.handle = self.relu.register_backward_pre_hook(backward_pre_hook_fn)
+            ...
+            ...     def construct(self, x):
+            ...         x = x + x
+            ...         x = self.relu(x)
+            ...         return x
+            >>> grad = ops.GradOperation(get_all=True)
+            >>> net = Net()
+            >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
+            backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
+            >>> print(output)
+            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
+        """
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
+        if not check_hook_fn("register_backward_pre_hook", hook_fn):
+            return HookHandle()
+        handle = HookHandle(self._backward_pre_hook)
+        self._backward_pre_hook[handle.handle_id] = hook_fn
+        if self._cell_backward_pre_hook is None:
+            # Generate a CellBackwardHook prim, and add function for it
+            self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
+                                                                  self, self._backward_pre_hook)
+            self._cell_backward_pre_hook.register_backward_pre_hook()
+        return handle
+
+    def _run_backward_pre_hook(self, outputs):
+        """
+        Running backward pre hook function registered on Cell object.
+
+        Args:
+            outputs: The output objects of cell object.
+
+        Returns:
+            - **outputs** - New backward gradient or None.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+        """
+        ret = self._cell_backward_pre_hook(outputs)
+        if isinstance(outputs, tuple):
+            if not isinstance(ret, tuple):
+                ret = (ret,)
+            if len(ret) != len(outputs):
+                raise TypeError(
+                    "The backward pre hook return value size is {} not equal to output size {}".format(
+                        len(ret), len(outputs)))
+        return ret
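Besides observing gradients, a backward pre hook can rewrite the gradient flowing into the cell by returning a replacement value, and the returned handle detaches it again. A minimal sketch under the assumption that a tuple shaped like `grad_output` is accepted as the replacement; the scaling factor is illustrative:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, ops
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def halve_grad(cell, grad_output):
...     # assumption: returning a tuple like grad_output replaces the incoming gradient
...     return (grad_output[0] * 0.5,)
...
>>> net = nn.ReLU()
>>> handle = net.register_backward_pre_hook(halve_grad)
>>> grads = ops.GradOperation(get_all=True)(net)(Tensor(np.ones([2]).astype(np.float32)))
>>> handle.remove()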
@@ -2145,11 +2352,11 @@ class Cell(Cell_):
|
|
|
2145
2352
|
Note:
|
|
2146
2353
|
- The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
|
|
2147
2354
|
- The 'hook_fn' must be defined as the following code.
|
|
2148
|
-
`
|
|
2149
|
-
|
|
2150
|
-
|
|
2355
|
+
`cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
|
|
2356
|
+
the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
|
|
2357
|
+
passed to the Cell.
|
|
2151
2358
|
- The 'hook_fn' should have the following signature:
|
|
2152
|
-
hook_fn(
|
|
2359
|
+
hook_fn(cell, grad_input, grad_output) -> New grad_input gradient or none.
|
|
2153
2360
|
- The 'hook_fn' is executed in the python environment. In order to prevent running failed when switching to
|
|
2154
2361
|
graph mode, it is not recommended to write it in the `construct` function of Cell object. In the pynative
|
|
2155
2362
|
mode, if the `register_backward_hook` function is called in the `construct` function of the Cell object,
|
|
@@ -2159,8 +2366,8 @@ class Cell(Cell_):
|
|
|
2159
2366
|
hook_fn (function): Python function. Backward hook function.
|
|
2160
2367
|
|
|
2161
2368
|
Returns:
|
|
2162
|
-
|
|
2163
|
-
|
|
2369
|
+
A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
|
|
2370
|
+
`handle.remove()` .
|
|
2164
2371
|
|
|
2165
2372
|
Raises:
|
|
2166
2373
|
TypeError: If the `hook_fn` is not a function of python.
|
|
@@ -2173,9 +2380,9 @@ class Cell(Cell_):
|
|
|
2173
2380
|
>>> import mindspore as ms
|
|
2174
2381
|
>>> from mindspore import Tensor, nn, ops
|
|
2175
2382
|
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
|
|
2176
|
-
>>> def backward_hook_fn(
|
|
2177
|
-
... print("backward input: ",
|
|
2178
|
-
... print("backward output: ",
|
|
2383
|
+
>>> def backward_hook_fn(cell, grad_input, grad_output):
|
|
2384
|
+
... print("backward input: ", grad_output)
|
|
2385
|
+
... print("backward output: ", grad_input)
|
|
2179
2386
|
...
|
|
2180
2387
|
>>> class Net(nn.Cell):
|
|
2181
2388
|
... def __init__(self):
|
|
@@ -2195,22 +2402,17 @@ class Cell(Cell_):
|
|
|
2195
2402
|
>>> print(output)
|
|
2196
2403
|
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
|
|
2197
2404
|
"""
|
|
2198
|
-
if context.
|
|
2199
|
-
logger.warning(f"'register_backward_hook' function is only supported in pynative mode, you can use "
|
|
2200
|
-
f"context.set_context to set pynative mode.")
|
|
2405
|
+
if context._get_mode() == context.GRAPH_MODE:
|
|
2201
2406
|
return HookHandle()
|
|
2202
|
-
|
|
2203
|
-
|
|
2204
|
-
|
|
2205
|
-
|
|
2407
|
+
if not check_hook_fn("register_backward_hook", hook_fn):
|
|
2408
|
+
return HookHandle()
|
|
2409
|
+
handle = HookHandle(self._backward_hook)
|
|
2410
|
+
self._backward_hook[handle.handle_id] = hook_fn
|
|
2206
2411
|
if self._cell_backward_hook is None:
|
|
2207
|
-
|
|
2208
|
-
self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")"
|
|
2209
|
-
|
|
2210
|
-
|
|
2211
|
-
else:
|
|
2212
|
-
backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
|
|
2213
|
-
handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
|
|
2412
|
+
# Generate a CellBackwardHook prim, and add function for it
|
|
2413
|
+
self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
|
|
2414
|
+
self, self._backward_hook)
|
|
2415
|
+
self._cell_backward_hook.register_backward_hook()
|
|
2214
2416
|
return handle
|
|
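Per the revised note, `register_backward_hook` now expects the three-argument signature `hook_fn(cell, grad_input, grad_output)`, where returning a value overrides the gradient passed on to the next cell. A short sketch of a gradient-modifying hook under that contract (MindSpore 2.4+, PyNative mode assumed; whether `grad_input` arrives as a tuple or a bare tensor follows the installed version's behavior, so both cases are handled):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn, ops

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def scale_grad(cell, grad_input, grad_output):
        # Halve whatever gradient is about to be passed to the next cell;
        # returning None instead would leave the gradient unchanged.
        if isinstance(grad_input, tuple):
            return tuple(g * 0.5 for g in grad_input)
        return grad_input * 0.5

    relu = nn.ReLU()
    handle = relu.register_backward_hook(scale_grad)
    grad_fn = ops.GradOperation(get_all=True)
    out = grad_fn(relu)(Tensor(np.ones([2]).astype(np.float32)))
    print(out)  # gradients scaled by the hook
    handle.remove()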
2215 2417
2216 2418     def _backward_hook_construct(self, *inputs, **kwargs):
@@ -2227,15 +2429,31 @@ class Cell(Cell_):
2227 2429         Supported Platforms:
2228 2430             ``Ascend`` ``GPU`` ``CPU``
2229 2431         """
2230 -
2231 -
2232 -
2233 -
2234 -
2235 -
2236 -
2432 +         # cell_backward_hook has CellBackwardHook op, so keep input args as they are.
2433 +         outputs = self._cell_backward_hook(*inputs)
2434 +         # If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
2435 +         # a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
2436 +         # Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
2437 +         is_need_unwrap = False
2438 +         if isinstance(outputs, tuple) and len(inputs) != 1:
2439 +             is_need_unwrap = True
2440 +
2441 +         if self._recompute_cell is not None:
2442 +             if is_need_unwrap:
2443 +                 outputs = self._recompute_cell(*outputs, **kwargs)
2444 +             else:
2445 +                 outputs = self._recompute_cell(outputs, **kwargs)
2446 +         elif self.has_bprop:
2447 +             if is_need_unwrap:
2448 +                 outputs = self._call_custom_bprop(*outputs, **kwargs)
2449 +             else:
2450 +                 outputs = self._call_custom_bprop(outputs, **kwargs)
2237 2451         else:
2238 -
2452 +             if is_need_unwrap:
2453 +                 outputs = self.construct(*outputs, **kwargs)
2454 +             else:
2455 +                 outputs = self.construct(outputs, **kwargs)
2456 +
2239 2457         outputs = self._cell_backward_hook(outputs)
2240 2458         return outputs
2241 2459
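The unwrap decision above hinges on one condition: the hook op returns a tuple whenever the cell received zero or two-plus positional arguments, and a bare value when it received exactly one. A standalone plain-Python sketch of that rule (no MindSpore required; the stand-in for CellBackwardHook simply echoes its inputs):

    def forward_through(construct, *inputs):
        # Stand-in for `self._cell_backward_hook(*inputs)`: echoes its args,
        # wrapped in a tuple unless there was exactly one of them.
        outputs = inputs if len(inputs) != 1 else inputs[0]
        is_need_unwrap = isinstance(outputs, tuple) and len(inputs) != 1
        return construct(*outputs) if is_need_unwrap else construct(outputs)

    print(forward_through(lambda x: x + 1, 41))          # one arg: passed bare -> 42
    print(forward_through(lambda a, b: a + b, 20, 22))   # two args: tuple splatted -> 42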
@@ -2365,6 +2583,9 @@ class Cell(Cell_):
2365 2583             introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
2366 2584             Default: ``False`` .
2367 2585         """
2586 +         if context.get_context("mode") == context.PYNATIVE_MODE:
2587 +             self._recompute_cell = recompute_registry.get()(self.construct)
2588 +             return
2368 2589         self._recompute()
2369 2590         if 'mp_comm_recompute' in kwargs.keys():
2370 2591             self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
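With this hunk, calling `Cell.recompute()` in PyNative mode wraps `construct` through the recompute registry and returns early, instead of taking the graph-mode flag path. A minimal invocation sketch, assuming MindSpore 2.4+ (the network itself is illustrative; the memory savings only materialize during the backward pass):

    import mindspore as ms
    from mindspore import nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    class Block(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(4, 4)
            self.relu = nn.ReLU()

        def construct(self, x):
            return self.relu(self.dense(x))

    block = Block()
    block.recompute()  # in PyNative mode this now wraps construct for recomputation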
@@ -2384,16 +2605,13 @@ class Cell(Cell_):
2384 2605                              "the key kwargs must be 'mp_comm_recompute', "
2385 2606                              "'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)
2386 2607
2608 +     @deprecated("2.3", "infer_param_pipeline_stage")
2387 2609     def infer_param_pipeline_stage(self):
2388 2610         """
2389 2611         Infer pipeline stages of all parameters in the cell.
2390 2612
2391 2613         Note:
2392 -             -
2393 -               the parameter should use add_pipeline_stage to add it's pipeline_stage information.
2394 -             - If a parameter P has been used by two operators in different stages "stageA" and "stageB",
2395 -               the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
2396 -               to add it's stage information before using infer_param_pipeline_stage.
2614 +             - The interface is deprecated from version 2.3 and will be removed in a future version.
2397 2615
2398 2616         Returns:
2399 2617             The params belong to current stage in pipeline parallel.
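The method is now wrapped with `@deprecated("2.3", "infer_param_pipeline_stage")`. MindSpore's own helper is not shown in this diff, so the following is only an illustrative stand-in for a decorator of that shape, demonstrating the usual pattern of warning on call while preserving behavior:

    import functools
    import warnings

    def deprecated(since, name):
        # Illustrative stand-in; MindSpore's `deprecated` helper may differ.
        def wrap(func):
            @functools.wraps(func)
            def inner(*args, **kwargs):
                warnings.warn(f"'{name}' is deprecated since version {since} "
                              f"and will be removed in a future version.",
                              DeprecationWarning, stacklevel=2)
                return func(*args, **kwargs)
            return inner
        return wrap

    @deprecated("2.3", "infer_param_pipeline_stage")
    def infer_param_pipeline_stage():
        return []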
@@ -2448,86 +2666,6 @@ class Cell(Cell_):
2448 2666         for op in all_ops:
2449 2667             op.place(role, rank_id)
2450 2668
2451 -     def _check_dynamic_tensor(self, set_input, net_input, index):
2452 -         """
2453 -         Check if tensor is correctly set for dynamic shape.
2454 -
2455 -         Args:
2456 -             set_input (Tensor): Tensor set for dynamic shape.
2457 -             net_input (Tensor): Input tensor of the Cell object.
2458 -             index (int): Tensor index for set inputs.
2459 -         """
2460 -         if not isinstance(net_input, Tensor):
2461 -             raise TypeError(
2462 -                 f"For 'set_inputs' and tuple(list) in 'set_inputs',the type of {index + 1}th input must be Tensor, "
2463 -                 f"but got {type(net_input)}.")
2464 -         is_param_set_input = isinstance(set_input, Parameter)
2465 -         is_param_net_input = isinstance(net_input, Parameter)
2466 -         if (is_param_set_input and not is_param_net_input) or (is_param_net_input and not is_param_set_input):
2467 -             raise TypeError(
2468 -                 f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
2469 -                 f"as network's input, but got 'set_inputs': {type(set_input)} and network's input: {type(net_input)}.")
2470 -         if set_input.dtype != net_input.dtype:
2471 -             raise TypeError(
2472 -                 f"For 'set_inputs' and tuple(list) in 'set_inputs',the dtype of {index + 1}th input must be the same "
2473 -                 f"as network's input, but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
2474 -         if -2 not in set_input.shape:
2475 -             if net_input.dim() != 0 and set_input.dim() != net_input.dim():
2476 -                 raise ValueError(
2477 -                     f"For 'set_inputs' and tuple(list) in 'set_inputs',the dims of {index + 1}th input must be the "
2478 -                     f"same as network's input, but got 'set_inputs': {set_input.dim()} and network's input: "
2479 -                     f"{net_input.dim()}.")
2480 -             if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
2481 -                 raise ValueError(
2482 -                     f"For 'set_inputs' and tuple(list) in 'set_inputs',the shape of {index + 1}th input must be the "
2483 -                     f"same as network's input, but got 'set_inputs': {set_input.shape} and network's input: "
2484 -                     f"{net_input.shape}.")
2485 -
2486 -     def _check_compile_dynamic_shape(self, set_inputs, net_inputs):
2487 -         """
2488 -         Check if graph has been compiled with dynamic shape.
2489 -
2490 -         Args:
2491 -             net_inputs (tuple): Inputs of the Cell object.
2492 -         """
2493 -         set_inputs_len = len(set_inputs)
2494 -         net_inputs_len = len(net_inputs)
2495 -         if set_inputs_len != net_inputs_len:
2496 -             raise ValueError("The length of 'set_inputs' or tuple(list) in 'set_inputs' must be equal to network's "
2497 -                              f"inputs, but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
2498 -         for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
2499 -             if isinstance(set_input, Tensor):
2500 -                 self._check_dynamic_tensor(set_input, net_input, index)
2501 -             elif isinstance(set_input, (tuple, list)):
2502 -                 if not isinstance(net_input, (tuple, list)):
2503 -                     raise TypeError(
2504 -                         f"The {index + 1}th input type of 'set_inputs' or tuple(list) in 'set_inputs' must be tuple or "
2505 -                         f"list, but got {type(net_input)}.")
2506 -                 self._check_compile_dynamic_shape(set_input, net_input)
2507 -             else:
2508 -                 if context._get_mode() == context.PYNATIVE_MODE and set_input is None:
2509 -                     continue
2510 -                 if net_input != set_input:
2511 -                     raise ValueError(
2512 -                         f"The {index + 1}th input of 'set_inputs' or tuple(list) in 'set_inputs' must be the same with "
2513 -                         f"network's input, but got set_inputs: {set_input} and network's input: {net_input}.")
2514 -
2515 -     def _run_tracefunc(self, *args, **kwargs):
2516 -         """ Run Packed Cell in Pack."""
2517 -         args = self._mixed_precision_cast(args)
2518 -         need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
2519 -         if not PackFunc.current.is_pynative_mode and need_subgraph:
2520 -             expander = PackExpander.get_instance()
2521 -             args = expander.begin_subgraph(self, *args)
2522 -             args = [_convert_tensor(a) for a in args]
2523 -             output = self._run_construct(args, kwargs)
2524 -             ret = expander.end_subgraph(self, output)
2525 -             output = _convert_tensor(ret)
2526 -         else:
2527 -             with _SetMixedPrecision(self):
2528 -                 output = self._run_construct(args, kwargs)
2529 -         return output
2530 -
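The removed `_check_dynamic_tensor` enforced, among other checks, that every dimension declared via `set_inputs` is either -1 (dynamic) or equal to the corresponding real input dimension. A standalone plain-Python restatement of just that shape rule (it omits the -2 fully-dynamic sentinel and the dtype/Parameter checks handled by the removed code):

    def shapes_compatible(set_shape, net_shape):
        # Each declared dim must be -1 (dynamic) or match the actual dim.
        if len(set_shape) != len(net_shape):
            return False
        return all(s in (-1, n) for s, n in zip(set_shape, net_shape))

    print(shapes_compatible((-1, 224, 224, 3), (8, 224, 224, 3)))  # True: batch dim left dynamic
    print(shapes_compatible((-1, 112, 112, 3), (8, 224, 224, 3)))  # False: fixed dims disagree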
2531 2669     def _mixed_precision_cast(self, inputs):
2532 2670         mixed_type = self.get_mixed_precision_type()
2533 2671         if mixed_type == MixedPrecisionType.NOTSET:
@@ -2633,7 +2771,7 @@ class GraphCell(Cell):
2633 2771         self._add_attr("graph_load_from_mindir", self.graph)
2634 2772         if not self.obf_random_seed:
2635 2773             return self.compile_and_run(*args, **kwargs)
2636 -         append_input = Tensor((numpy.ones((1,
2774 +         append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32))
2637 2775         return self.compile_and_run(*args, append_input, **kwargs)
2638 2776
2639 2777
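The rewritten GraphCell line builds a one-element int32 array carrying the branch-control value, which is then wrapped in a mindspore Tensor and appended to the `compile_and_run` arguments. A numpy-only sketch of the same construction (the concrete value is illustrative; in GraphCell it is derived from `obf_random_seed`):

    import numpy as np

    branch_control_input = 12345  # illustrative; the real value comes from obf_random_seed
    append_input = (np.ones((1,)) * branch_control_input).astype(np.int32)
    print(append_input, append_input.dtype)  # [12345] int32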