mindspore-2.2.14-cp39-cp39-win_amd64.whl → mindspore-2.3.0-cp39-cp39-win_amd64.whl
This diff compares the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +76 -18
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +258 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +174 -62
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +142 -240
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +51 -24
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +15 -4
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +8 -9
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +411 -106
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +6 -8
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +260 -0
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +30 -11
- mindspore/common/recompute.py +262 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +272 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +468 -494
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +1140 -0
- mindspore/communication/management.py +115 -102
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +346 -63
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +140 -83
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +78 -59
- mindspore/dataset/engine/datasets_vision.py +77 -73
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +56 -38
- mindspore/dataset/engine/validators.py +11 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +8 -8
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +14 -2
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +339 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/dataset/vision.h +54 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +2 -2
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +76 -58
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/mint/__init__.py +1137 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +512 -0
- mindspore/mint/nn/functional.py +573 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +185 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +72 -0
- mindspore/nn/__init__.py +1 -0
- mindspore/nn/cell.py +213 -257
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
- mindspore/nn/extend/layer/normalization.py +109 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/layer/activation.py +83 -93
- mindspore/nn/layer/basic.py +177 -82
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +101 -43
- mindspore/nn/layer/embedding_service.py +531 -0
- mindspore/nn/layer/embedding_service_layer.py +393 -0
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +52 -66
- mindspore/nn/layer/padding.py +30 -39
- mindspore/nn/layer/pooling.py +18 -9
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +62 -83
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +58 -13
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +32 -9
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +6 -6
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +62 -68
- mindspore/numpy/utils.py +3 -0
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +89 -34
- mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +164 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +130 -58
- mindspore/ops/_vmap/vmap_nn_ops.py +249 -115
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +980 -0
- mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +121 -23
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +191 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +53 -0
- mindspore/ops/extend/array_func.py +218 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/extend/nn_func.py +308 -0
- mindspore/ops/function/__init__.py +31 -11
- mindspore/ops/function/array_func.py +846 -1735
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +1 -4
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +27 -20
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +913 -2791
- mindspore/ops/function/nn_func.py +1439 -885
- mindspore/ops/function/other_func.py +6 -7
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +254 -108
- mindspore/ops/function/reshard_func.py +102 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +14 -14
- mindspore/ops/functional.py +342 -343
- mindspore/ops/op_info_register.py +16 -43
- mindspore/ops/operations/__init__.py +32 -23
- mindspore/ops/operations/_grad_ops.py +21 -853
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +107 -518
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +108 -2705
- mindspore/ops/operations/comm_ops.py +801 -118
- mindspore/ops/operations/custom_ops.py +61 -120
- mindspore/ops/operations/debug_ops.py +104 -35
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
- mindspore/ops/operations/math_ops.py +572 -4667
- mindspore/ops/operations/nn_ops.py +248 -2162
- mindspore/ops/operations/other_ops.py +53 -45
- mindspore/ops/operations/random_ops.py +4 -53
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +204 -103
- mindspore/ops/silent_check.py +5 -5
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +250 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_ops.py +1084 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +968 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +354 -0
- mindspore/ops_generate/template.py +239 -0
- mindspore/parallel/__init__.py +6 -4
- mindspore/parallel/_auto_parallel_context.py +73 -3
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +29 -13
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +18 -11
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +161 -6
- mindspore/parallel/algo_parameter_config.py +6 -8
- mindspore/parallel/checkpoint_transform.py +191 -32
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +344 -0
- mindspore/parallel/cluster/process_entity/_utils.py +126 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +128 -17
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +3 -2
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +125 -0
- mindspore/profiler/envprofiling.py +2 -2
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +147 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +296 -133
- mindspore/profiler/parser/base_timeline_generator.py +6 -0
- mindspore/profiler/parser/framework_parser.py +3 -2
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +16 -1
- mindspore/profiler/profiling.py +445 -190
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -5
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +143 -29
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +238 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_mindio_ttp.py +443 -0
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +7 -7
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +60 -21
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +89 -15
- mindspore/train/model.py +290 -60
- mindspore/train/serialization.py +495 -220
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/METADATA +3 -3
- mindspore-2.3.0.dist-info/RECORD +1400 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2020-
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,10 +20,9 @@ import inspect
 import os
 import time
 from collections import OrderedDict
-from types import FunctionType, MethodType
 import numpy
 
-from mindspore._checkparam import args_type_check
+from mindspore._checkparam import args_type_check, check_hook_fn
 from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
 from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
@@ -34,7 +33,8 @@ from mindspore._c_expression import init_pipeline, update_func_graph_hyper_param
 from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
-from mindspore.common.api import _generate_branch_control_input
+from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
+from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.ops.operations import Cast
@@ -43,15 +43,7 @@ from mindspore.ops.operations import _inner_ops as inner
 from mindspore.parallel.shard import Shard
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
 from mindspore.common._decorator import deprecated
-from mindspore.
-from mindspore.ops._tracefunc import _convert_tensor, _SetMixedPrecision, PackFunc
-
-
-def _check_args(args):
-    """Check the input args's type"""
-    for item in args:
-        if isinstance(item, Tensor) and item.has_init:
-            item.init_data()
+from mindspore.common._register_for_recompute import recompute_registry
 
 
 class Cell(Cell_):
@@ -89,7 +81,7 @@ class Cell(Cell_):
 
     Examples:
         >>> import mindspore.nn as nn
-        >>>
+        >>> from mindspore import ops
        >>> class MyCell(nn.Cell):
        ...     def __init__(self, forward_net):
        ...         super(MyCell, self).__init__(auto_prefix=False)
@@ -109,17 +101,19 @@ class Cell(Cell_):
     """
 
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
-                   '_func_graph_flags', '_parameter_layout_dict', '_params_list', '
+                   '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase',
                    '_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
                    '_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
                    '_attr_synced', 'pynative', 'requires_grad', 'cell_type']
+    total_instance_count = 0
 
     def __init__(self, auto_prefix=True, flags=None):
         Cell_.__init__(self, self._cell_tag)
+        Cell.total_instance_count += 1
+        self.instance_count = Cell.total_instance_count
         self._params = OrderedDict()
         self._cells = OrderedDict()
         self._params_list = OrderedDict()
-        self._tensor_list = OrderedDict()
         self._primitives = OrderedDict()
         self.training = False
         self.requires_grad = False
@@ -135,11 +129,14 @@ class Cell(Cell_):
         self._create_time = int(time.time() * 1e9)
         self.arguments_key = ""
         self.compile_cache = set()
+        self.phase_cache = dict()
         cells_compile_cache[id(self)] = self.compile_cache
         self.parameter_broadcast_done = False
         self._id = 1
         self.exist_names = set("")
         self.exist_objs = set()
+        self.recompute_cell = None
+        self.sig = inspect.signature(self.construct)
         init_pipeline()
 
         # call gc to release GE session resources used by non-used cell objects
@@ -161,12 +158,14 @@ class Cell(Cell_):
         self._has_config_recompute = False
         self._user_parameters = []
         self._dynamic_shape_inputs = None
+        self._compile_args = None
         self.saved_dynamic_shape = None
         self._jit_config_dict = dict()
         self.grad_ops_label = False
         self.ge_sync_data = False
         self._is_check_and_refresh = False
         self._amp_level = ""
+        self._init_flag = False
 
     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -225,7 +224,7 @@ class Cell(Cell_):
 
         Tutorial Examples:
             - `Cell and Parameter - Custom Cell Reverse
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#custom-cell-reverse>`_
         """
         return self._bprop_debug
 
@@ -317,10 +316,23 @@ class Cell(Cell_):
 
     @property
     def pipeline_stage(self):
+        """
+        `pipeline_stage` represents the pipeline stage of current Cell.
+        """
         return self._pipeline_stage
 
     @pipeline_stage.setter
     def pipeline_stage(self, value):
+        """
+        Set the `pipeline_stage` of a Cell.
+
+        Args:
+            value (int): The pipeline stage of a parameter.
+
+        Raises:
+            TypeError: If `value` is not int type or is a bool type.
+            ValueError: If `value` is not a positive integer.
+        """
         if not isinstance(value, int) or isinstance(value, bool):
             raise TypeError("For 'Cell', the property 'pipeline_stage' "
                             "must be int type, but got type : {}".format(type(value)))
@@ -376,10 +388,6 @@ class Cell(Cell_):
         cells = self.__dict__['_cells']
         if name in cells:
             return cells[name]
-        if '_tensor_list' in self.__dict__:
-            tensor_list = self.__dict__['_tensor_list']
-            if name in tensor_list:
-                return tensor_list[name]
         if '_params_list' in self.__dict__:
             params_list = self.__dict__['_params_list']
             if name in params_list:
@@ -391,12 +399,11 @@ class Cell(Cell_):
         # while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache
         # here using pop(id(self), None) to avoid KeyError exception
         cells_compile_cache.pop(id(self), None)
-
-
-
-
-
-                                  f"Please use 'super().__init__()'.") from e
+        if hasattr(self, "compile_cache") and self.compile_cache:
+            _cell_graph_executor.del_net_res(self, self.compile_cache)
+        if isinstance(self, GraphCell):
+            _cell_graph_executor.dec_graph_cell_count()
+        Cell.total_instance_count -= 1
 
     def __delattr__(self, name):
         if name in self._params:
@@ -405,8 +412,6 @@ class Cell(Cell_):
             del self._cells[name]
         elif '_params_list' in self.__dict__ and name in self._params_list:
             del self._params_list[name]
-        elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
-            del self._tensor_list[name]
         else:
             object.__delattr__(self, name)
         self._attr_synced = False
@@ -420,7 +425,7 @@ class Cell(Cell_):
             elif isinstance(item, float):
                 res.append(self.cast(item, dst_type))
             elif hasattr(item, "dtype") and item.dtype in \
-
+                    {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
                 res.append(self.cast(item, dst_type))
             else:
                 res.append(item)
@@ -479,7 +484,10 @@ class Cell(Cell_):
         elif hasattr(self, "_shard_fn"):
             output = self._shard_fn(*cast_inputs, **kwargs)
         else:
-
+            if self.recompute_cell is not None:
+                output = self.recompute_cell(*cast_inputs, **kwargs)
+            else:
+                output = self.construct(*cast_inputs, **kwargs)
         if self._enable_forward_hook:
             output = self._run_forward_hook(cast_inputs, output)
         return output
@@ -566,8 +574,9 @@ class Cell(Cell_):
     def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
         """
         Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
-        generated by sharding propagation. In PyNative mode, use this method
-        to specify
+        generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
+        execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
+        strategy for others will be set by sharding propagation.
         in_strategy and out_strategy define the input and output layout respectively.
         in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
         this input/output, and None represents data_parallel,
@@ -575,8 +584,8 @@ class Cell(Cell_):
        The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
 
        Note:
-
-
+            If Cell.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
+            "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
            If the input contain Parameter, its strategy should be set in `in_strategy`.
 
        Args:
@@ -598,7 +607,7 @@ class Cell(Cell_):
                use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .
 
        Returns:
-
+            Function, return the cell construct function that will be executed under auto parallel process.
 
        Examples:
            >>> import mindspore.nn as nn
@@ -616,22 +625,21 @@ class Cell(Cell_):
            ...     def __init__(self):
            ...         self.block1 = Block()
            ...         self.block2 = Block()
-            ...         self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,),
-            ...
+            ...         self.block2_shard = self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,),
+            ...                                               parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
            ...     def construct(self, x):
            ...         x = self.block1(x)
-            ...         x = self.
+            ...         x = self.block2_shard(x)
            ...         return x
        """
-        if context.
-
-
-                               f"Please check if you call Cell.shard in the script.")
+        if context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel", "semi_auto_parallel"]:
+            raise AssertionError(f"Cell shard only supports auto parallel or semi_auto_parallel "
+                                 f"Please check the parallel mode in parallel context.")
 
        shard_fn = Shard()
        fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
        object.__setattr__(self, "_shard_fn", fn)
-        return
+        return fn
 
    def auto_cast_inputs(self, inputs):
        """
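Editor's note on the `shard` hunks above: `Cell.shard` now returns the sharded callable (`return fn`), and the docstring example assigns and calls it instead of the original sub-cell. A minimal usage sketch under stated assumptions: the `Block` cell here is invented for illustration, and the parallel context must already be "auto_parallel" or "semi_auto_parallel", otherwise the new AssertionError above fires.

    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import ops

    ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation")

    class Block(nn.Cell):
        # stand-in for any user sub-network
        def construct(self, x):
            return ops.relu(x)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.block = Block()
            # keep the returned callable; shard() no longer returns None
            self.block_shard = self.block.shard(in_strategy=((2, 1),), out_strategy=(None,))

        def construct(self, x):
            # call the sharded function, not self.block, so the strategy applies
            return self.block_shard(x)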
@@ -654,48 +662,56 @@ class Cell(Cell_):
 
         return cast_inputs
 
-    def
-
-
+    def _init_check(self):
+        for param in self.get_parameters(expand=False):
+            if param.has_init:
+                param.init_data()
 
-
-
-            bound_arguments.apply_defaults()
-            args = bound_arguments.args
-            kwargs = bound_arguments.kwargs
-
-        if PackFunc.is_tracing():
-            return self._run_tracefunc(*args, **kwargs)
-
-        if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
+    def _self_check(self):
+        if not self._is_check_and_refresh:
             self.check_names_and_refresh_name()
             self._is_check_and_refresh = True
 
+    def _predict(self, *args, **kwargs):
+        if not hasattr(self, "phase"):
+            return False, None
+        if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
+            new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
+            res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
+            res = _convert_python_data(res)
+            return True, res
+        return False, None
+
+    def __call__(self, *args, **kwargs):
         # Run in Graph mode.
         if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
+            if kwargs:
+                bound_arguments = self.sig.bind(*args, **kwargs)
+                bound_arguments.apply_defaults()
+                args = bound_arguments.args
+                kwargs = bound_arguments.kwargs
+
+            predict_compiled, res = self._predict(*args, **kwargs)
+            if predict_compiled:
+                return res
             self._check_construct_args(*args)
+
             if self._hook_fn_registered():
                 logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
                                f"function, please use context.set_context to set pynative mode.")
+            self._self_check()
             out = self.compile_and_run(*args, **kwargs)
             return out
 
         # Run in PyNative mode.
-
-
-
-
-        self._do_parameter_broadcast()
+        self._self_check()
+        if not self._init_flag:
+            self._init_check()
+            self._init_flag = True
 
-
-        self._check_cell_flags_in_pynative()
-
-        if self.requires_grad and _pynative_executor.enable_grad():
+        if self.requires_grad:
             _pynative_executor.set_grad_flag(True)
 
-        if self._dynamic_shape_inputs is not None:
-            self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
-
         try:
             _pynative_executor.new_graph(self, *args, **kwargs)
             output = self._run_construct(args, kwargs)
@@ -704,16 +720,8 @@ class Cell(Cell_):
             _pynative_executor.clear_res()
             raise err
 
-        if isinstance(output, Parameter):
-            output = output.data
         return output
 
-    def _check_cell_flags_in_pynative(self):
-        """Check the flags added to cell in pynative mode"""
-        if hasattr(self, "_func_graph_flags") and self._func_graph_flags.get("output_no_recompute"):
-            raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
-                            "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
-
     def _add_attr(self, name, value):
         if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
             super(Cell, self)._add_attr(name, value)
@@ -830,15 +838,6 @@ class Cell(Cell_):
         else:
             self.insert_param_to_cell(name, None)
 
-    def _set_attr_for_tensor(self, name, value):
-        if context._get_mode() == context.PYNATIVE_MODE:
-            tensor_list = self.__dict__.get('_tensor_list')
-            if name in self.__dict__:
-                del self.__dict__[name]
-            tensor_list[name] = value
-        else:
-            object.__setattr__(self, name, value)
-
     def __setattr__(self, name, value):
         cells = self.__dict__.get('_cells')
         params = self.__dict__.get('_params')
@@ -856,8 +855,6 @@ class Cell(Cell_):
             if value is not None:
                 raise TypeError(f"For 'Cell', the type of {name} must be cell, but got {type(value).__name__}.")
             self._cells[name] = None
-        elif isinstance(value, Tensor):
-            self._set_attr_for_tensor(name, value)
         else:
             if isinstance(value, Primitive):
                 value.set_prim_instance_name(name)
@@ -910,14 +907,25 @@ class Cell(Cell_):
         """
         logger.warning("'set_parallel_input_with_inputs' function is deprecated.")
 
-    def set_inputs(self, *inputs):
+    def set_inputs(self, *inputs, **kwargs):
         """
         Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When
         using Model for dynamic shape, please make sure that all networks and loss functions passed to the Model are
-        configured with set_inputs. The
+        configured with set_inputs. The shape of input Tensor can be either dynamic or static.
+
+        .. note::
+            There are two mode:
+
+            - Full mode: arguments will be used as all compile inputs for graph-compiling.
+            - Incremental mode: arguments will set to some of the Cell inputs, which will be substituted into the input
+              at the corresponding position for graph-compiling.
+
+            Only one of inputs or kwargs can be set. Inputs for full mode and kwargs for incremental mode.
 
         Args:
-            inputs (tuple):
+            inputs (tuple): Full mode arguments.
+            kwargs (dict): Incremental mode arguments. The acceptable key is the name of parameter defined
+                in `self.construct`.
 
         .. warning::
             This is an experimental API that is subject to change or deletion.
@@ -937,16 +945,27 @@ class Cell(Cell_):
             >>> net = ReluNet()
             >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
             >>> net.set_inputs(input_dyn)
-            >>>
-            >>> output = net(
+            >>> input = Tensor(np.random.random([3, 10]), dtype=ms.float32)
+            >>> output = net(input)
+            >>>
+            >>> net2 = ReluNet()
+            >>> net2.set_inputs(x=input_dyn)
+            >>> output = net2(input)
         """
         if self.grad_ops_label:
             logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
                            f'generated.')
-
-
-
-
+        if kwargs and inputs:
+            raise ValueError('For Cell, set_inputs should only set inputs or kwargs(inputs: %s, kwargs: %s)!'
+                             % (inputs, kwargs))
+
+        if not kwargs:
+            self._dynamic_shape_inputs = inputs
+            self._check_construct_args(*inputs)
+            if context._get_mode() == context.PYNATIVE_MODE:
+                _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
+        else:
+            self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)
 
     def get_inputs(self):
         """
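For quick reference, the new keyword form of `set_inputs` in the hunk above marks individual `construct` parameters as dynamic (incremental mode), while the positional form still covers all compile inputs (full mode). A minimal sketch expanding the docstring's own example; `ReluNet` and the shapes are illustrative:

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor, ops

    class ReluNet(nn.Cell):
        def construct(self, x):
            return ops.relu(x)

    input_dyn = Tensor(shape=[3, None], dtype=ms.float32)       # None marks a dynamic axis
    data = Tensor(np.random.random([3, 10]), dtype=ms.float32)

    net = ReluNet()
    net.set_inputs(input_dyn)        # full mode: positional arguments cover every input
    out = net(data)

    net2 = ReluNet()
    net2.set_inputs(x=input_dyn)     # incremental mode: keyed by the construct parameter 'x'
    out2 = net2(data)                # note: inputs and kwargs cannot be mixed (ValueError)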
@@ -981,6 +1000,48 @@ class Cell(Cell_):
 
         return self._dynamic_shape_inputs
 
+    def _check_parameter_consistency(self, set_inputs, net_inputs):
+        """Check consistency for parameter."""
+        for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
+            if isinstance(set_input, Tensor):
+                if not isinstance(net_input, Tensor):
+                    raise TypeError(
+                        f"For 'set_inputs' and tuple(list) in 'set_inputs',the type of {index + 1}th input must "
+                        f"be Tensor, but got {type(net_input)}.")
+                if isinstance(set_input, Parameter) != isinstance(net_input, Parameter):
+                    raise TypeError(
+                        f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
+                        f"as expected, but got expected: {type(set_input)} and input: {type(net_input)}.")
+            elif isinstance(set_input, (tuple, list)):
+                if not isinstance(net_input, (tuple, list)):
+                    raise TypeError(
+                        f"The {index + 1}th input type of 'set_inputs' or tuple(list) in "
+                        f"'set_inputs' must be tuple or list, but got {type(net_input)}.")
+                self._check_parameter_consistency(set_input, net_input)
+
+    def _get_compile_args(self, args):
+        """Get compile arguments."""
+        # this is used only for test
+        set_by_auto_dynamic = False
+        if is_auto_dynamic():
+            if self._dynamic_shape_inputs is None:
+                set_by_auto_dynamic = True
+            else:
+                if isinstance(self._dynamic_shape_inputs, (list, tuple)) and self._dynamic_shape_inputs[0] is None:
+                    set_by_auto_dynamic = True
+        if set_by_auto_dynamic:
+            self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
+
+        if self._dynamic_shape_inputs is not None:
+            logger.debug("Compiled Graph with dynamic shape")
+            compile_args = _generate_dyn_compile_args(args, self._dynamic_shape_inputs)
+            _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs")
+            self._check_parameter_consistency(compile_args, args)
+            Validator.check_symbolic_shape(compile_args, args)
+            self.saved_dynamic_shape = compile_args
+            return compile_args
+        return args
+
     def compile(self, *args, **kwargs):
         """
         Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
@@ -989,19 +1050,9 @@ class Cell(Cell_):
             args (tuple): Args of the Cell object.
             kwargs (dict): Kwargs of the Cell object.
         """
-
-
-
-
-        if self._dynamic_shape_inputs is None:
-            _cell_graph_executor.compile(self, phase=self.phase,
-                                         jit_config_dict=self._jit_config_dict, *args, **kwargs)
-        else:
-            self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
-            self.saved_dynamic_shape = self._dynamic_shape_inputs
-            _cell_graph_executor.compile(self, *self._dynamic_shape_inputs, phase=self.phase,
-                                         jit_config_dict=self._jit_config_dict, **kwargs)
-            logger.debug("Compiled Graph with dynamic shape")
+        self._compile_args = self._get_compile_args(args)
+        _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
+                                     jit_config_dict=self._jit_config_dict, **kwargs)
 
     def compile_and_run(self, *args, **kwargs):
         """
@@ -1019,7 +1070,7 @@ class Cell(Cell_):
         """
         self.compile(*args, **kwargs)
         self.add_flags(ge_sync_data=False)
-        new_args = _get_args_for_run(self, args, kwargs)
+        new_args = _get_args_for_run(self, args, kwargs, self._compile_args)
         return _cell_graph_executor(self, *new_args, phase=self.phase)
 
     def auto_parallel_compile_and_run(self):
@@ -1033,6 +1084,7 @@ class Cell(Cell_):
 
     def exec_checkpoint_graph(self):
         """Executes GE saving checkpoint graph operation."""
+        logger.warning("'exec_checkpoint_graph' function is deprecated.")
         self.add_flags(ge_sync_data=True)
         _cell_graph_executor(self, phase='save')
 
@@ -1070,14 +1122,14 @@ class Cell(Cell_):
         Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
         """
         if not param_name:
-            raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
+            raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not be None.")
         if check_name_contain_dot and '.' in param_name:
-            raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not contain
+            raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not contain'.' ")
         if '_params' not in self.__dict__:
-            raise AttributeError("For 'insert_param_to_cell', please call Cell.__init__() firstly.")
+            raise AttributeError(f"For 'insert_param_to_cell', please call Cell.__init__() firstly.")
         if hasattr(self, param_name) and param_name not in self._params:
-            raise KeyError("For 'insert_param_to_cell', the {} parameter already exists in the network.
-                           "insert another parameter with the same name."
+            raise KeyError(f"For 'insert_param_to_cell', the {param_name} parameter already exists in the network."
+                           f"Cannot insert another parameter with the same name.")
         if not isinstance(param, Parameter) and param is not None:
             raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
                             f"but got {type(param)}.")
@@ -1139,11 +1191,11 @@ class Cell(Cell_):
             raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
                             f"but got {type(child_name)}.")
         if not child_name or '.' in child_name:
-            raise KeyError("For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
-                           "can not contain '.'")
+            raise KeyError(f"For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
+                           "can not contain '.' ")
         if hasattr(self, child_name) and child_name not in self._cells:
-            raise KeyError("For 'insert_child_to_cell', the {} child cell already exists in the network.
-                           "insert another child cell with the same name."
+            raise KeyError(f"For 'insert_child_to_cell', the {child_name} child cell already exists in the network."
+                           f"Cannot insert another child cell with the same name.")
         if not isinstance(child_cell, Cell) and child_cell is not None:
             raise TypeError(f"For 'insert_child_to_cell', the argument 'child_cell' must be 'Cell' if not None, "
                             f"but got type {type(child_cell)}.")
@@ -1163,7 +1215,7 @@ class Cell(Cell_):
         Returns:
             Tensor, returns the computed result.
         """
-
+        raise AttributeError("For 'Cell', the method 'construct' is not defined.")
 
     def remove_redundant_parameters(self):
         """
@@ -1361,7 +1413,7 @@ class Cell(Cell_):
 
         Tutorial Examples:
             - `Model Training - Optimizer
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/beginner/train.html#optimizer>`_
         """
         return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
 
@@ -1472,7 +1524,7 @@ class Cell(Cell_):
 
         Tutorial Examples:
             - `Building a Network - Model Parameters
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/beginner/model.html#model-parameters>`_
         """
         cells = []
         if expand:
@@ -1698,6 +1750,9 @@ class Cell(Cell_):
         if not hasattr(self, "_func_graph_flags"):
             self._func_graph_flags = {}
         self._func_graph_flags.update({**flags})
+        if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
+            raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
+                            "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
         self.__dict__.update({**flags})
         self._add_mixed_precision_flag(**flags)
         return self
@@ -1808,7 +1863,7 @@ class Cell(Cell_):
         accelerate the algorithm in the algorithm library.
 
         If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
-        `algorithm library <https://gitee.com/mindspore/mindspore/tree/
+        `algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_.
 
         Note:
             Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1865,7 +1920,7 @@ class Cell(Cell_):
 
         Tutorial Examples:
             - `Model Training - Implementing Training and Evaluation
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/beginner/train.html#training-and-evaluation>`_
         """
         if mode:
             self._phase = 'train'
@@ -1945,11 +2000,11 @@ class Cell(Cell_):
         Note:
             - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
             - 'hook_fn' must be defined as the following code.
-              `
+              `cell` is the object of registered Cell. `inputs` is the forward
               input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
              forward input objects.
            - It should have the following signature:
-              hook_fn(
+              hook_fn(cell, inputs) -> new input objects or none.
            - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
              `construct` function of Cell object. In the pynative mode, if the `register_forward_pre_hook` function is
              called in the `construct` function of the Cell object, a hook function will be added at each run time of
@@ -1959,8 +2014,8 @@ class Cell(Cell_):
            hook_fn (function): Python function. Forward pre hook function.
 
        Returns:
-
-
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
 
        Raises:
            TypeError: If the `hook_fn` is not a function of python.
@@ -1973,7 +2028,7 @@ class Cell(Cell_):
            >>> import mindspore as ms
            >>> from mindspore import Tensor, nn, ops
            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
-            >>> def forward_pre_hook_fn(
+            >>> def forward_pre_hook_fn(cell, inputs):
            ...     print("forward inputs: ", inputs)
            ...
            >>> class Net(nn.Cell):
@@ -1995,17 +2050,8 @@ class Cell(Cell_):
            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
            value= [ 2.00000000e+00]))
        """
-        if
-            logger.warning(f"'register_forward_pre_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if not check_hook_fn("register_forward_pre_hook", hook_fn):
            return HookHandle()
-
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
-        if hook_fn.__code__.co_name == "staging_specialize":
-            raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-
        self._enable_forward_pre_hook = True
        _pynative_executor.set_hook_changed(self)
        if not hasattr(self, '_forward_pre_hook_key'):
@@ -2028,9 +2074,8 @@ class Cell(Cell_):
        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``
        """
-        cell_id = self.cls_name + "(" + str(id(self)) + ")"
        for fn in self._forward_pre_hook.values():
-            ret = fn(
+            ret = fn(self, inputs)
            if ret is not None:
                if not isinstance(ret, tuple):
                    inputs = (ret,)
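The hook hunks above change the first argument passed to forward hooks from an id value to the Cell instance itself (`ret = fn(self, inputs)`), and move validation into `check_hook_fn`. The removed `cell_id` local suggests the old callbacks received a cell-id string; treat that reading as an inference from the deletion. A sketch of the 2.3-style signatures, assuming PyNative mode as the docstrings require:

    import mindspore as ms
    from mindspore import Tensor, nn, ops

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def forward_pre_hook_fn(cell, inputs):        # first argument is now the Cell object
        print("pre :", type(cell).__name__, inputs)

    def forward_hook_fn(cell, inputs, output):    # likewise for the forward hook
        print("post:", output)

    class Net(nn.Cell):
        def construct(self, x):
            return ops.relu(x)

    net = Net()
    pre_handle = net.register_forward_pre_hook(forward_pre_hook_fn)
    fwd_handle = net.register_forward_hook(forward_hook_fn)
    out = net(Tensor([1.0], ms.float32))
    pre_handle.remove()   # per the Returns docs above, handles detach the hooks
    fwd_handle.remove()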
@@ -2045,11 +2090,11 @@ class Cell(Cell_):
         Note:
             - The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
             - 'hook_fn' must be defined as the following code.
-              `
+              `cell` is the object of registered Cell. `inputs` is the forward
               input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
              modify the forward output object by returning new forward output object.
            - It should have the following signature:
-              hook_fn(
+              hook_fn(cell, inputs, output) -> new output object or none.
            - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
              `construct` function of Cell object. In the pynative mode, if the `register_forward_hook` function is
              called in the `construct` function of the Cell object, a hook function will be added at each run time of
@@ -2059,8 +2104,8 @@ class Cell(Cell_):
            hook_fn (function): Python function. Forward hook function.
 
        Returns:
-
-
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
 
        Raises:
            TypeError: If the `hook_fn` is not a function of python.
@@ -2073,7 +2118,7 @@ class Cell(Cell_):
            >>> import mindspore as ms
            >>> from mindspore import Tensor, nn, ops
            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
-            >>> def forward_hook_fn(
+            >>> def forward_hook_fn(cell, inputs, output):
            ...     print("forward inputs: ", inputs)
            ...     print("forward output: ", output)
            ...
@@ -2097,17 +2142,8 @@ class Cell(Cell_):
            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
            value= [ 2.00000000e+00]))
        """
-        if
-            logger.warning(f"'register_forward_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if not check_hook_fn("register_forward_hook", hook_fn):
            return HookHandle()
-
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
-        if hook_fn.__code__.co_name == "staging_specialize":
-            raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-
        self._enable_forward_hook = True
        _pynative_executor.set_hook_changed(self)
        if not hasattr(self, '_forward_hook_key'):
@@ -2131,9 +2167,8 @@ class Cell(Cell_):
        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``
        """
-        cell_id = self.cls_name + "(" + str(id(self)) + ")"
        for fn in self._forward_hook.values():
-            ret = fn(
+            ret = fn(self, inputs, output)
            if ret is not None:
                output = ret
        return output
@@ -2159,8 +2194,8 @@ class Cell(Cell_):
            hook_fn (function): Python function. Backward hook function.
 
        Returns:
-
-
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
 
        Raises:
            TypeError: If the `hook_fn` is not a function of python.
@@ -2195,14 +2230,8 @@ class Cell(Cell_):
            >>> print(output)
            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
        """
-        if
-            logger.warning(f"'register_backward_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if not check_hook_fn("register_backward_hook", hook_fn):
            return HookHandle()
-
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
        if self._cell_backward_hook is None:
            self._enable_backward_hook = True
            self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
@@ -2232,10 +2261,16 @@ class Cell(Cell_):
        else:
            inputs = self._cell_backward_hook(*inputs)
            inputs = (inputs,)
-        if
-
+        if self.recompute_cell is not None:
+            if isinstance(inputs, tuple):
+                outputs = self.recompute_cell(*inputs, **kwargs)
+            else:
+                outputs = self.recompute_cell(inputs, **kwargs)
        else:
-
+            if isinstance(inputs, tuple):
+                outputs = self.construct(*inputs, **kwargs)
+            else:
+                outputs = self.construct(inputs, **kwargs)
        outputs = self._cell_backward_hook(outputs)
        return outputs
 
@@ -2365,6 +2400,9 @@ class Cell(Cell_):
            introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
            Default: ``False`` .
        """
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            self.recompute_cell = recompute_registry.get()(self.construct)
+            return
        self._recompute()
        if 'mp_comm_recompute' in kwargs.keys():
            self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
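Together with the `recompute_cell` plumbing added earlier in this diff (`__init__`, `_run_construct`, `_backward_hook_construct`), the hunk above makes `Cell.recompute()` effective in PyNative mode by wrapping `construct` through `recompute_registry` instead of raising. A minimal sketch of the call pattern; the network and shapes are invented for illustration:

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor, ops

    ms.set_context(mode=ms.PYNATIVE_MODE)

    class Block(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(4, 4)

        def construct(self, x):
            return ops.relu(self.dense(x))

    net = Block()
    net.recompute()   # in PyNative mode this now installs a recompute wrapper

    x = Tensor(np.ones((2, 4)), ms.float32)
    grad_fn = ms.grad(net, grad_position=0)
    dx = grad_fn(x)   # Block's activations are recomputed during this backward pass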
@@ -2384,16 +2422,13 @@ class Cell(Cell_):
                              "the key kwargs must be 'mp_comm_recompute', "
                              "'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)
 
+    @deprecated("2.3", "infer_param_pipeline_stage")
     def infer_param_pipeline_stage(self):
         """
         Infer pipeline stages of all parameters in the cell.
 
         Note:
-            -
-              the parameter should use add_pipeline_stage to add it's pipeline_stage information.
-            - If a parameter P has been used by two operators in different stages "stageA" and "stageB",
-              the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
-              to add it's stage information before using infer_param_pipeline_stage.
+            - The interface is deprecated from version 2.3 and will be removed in a future version.
 
         Returns:
             The params belong to current stage in pipeline parallel.
@@ -2448,86 +2483,6 @@ class Cell(Cell_):
         for op in all_ops:
             op.place(role, rank_id)
 
-    def _check_dynamic_tensor(self, set_input, net_input, index):
-        """
-        Check if tensor is correctly set for dynamic shape.
-
-        Args:
-            set_input (Tensor): Tensor set for dynamic shape.
-            net_input (Tensor): Input tensor of the Cell object.
-            index (int): Tensor index for set inputs.
-        """
-        if not isinstance(net_input, Tensor):
-            raise TypeError(
-                f"For 'set_inputs' and tuple(list) in 'set_inputs',the type of {index + 1}th input must be Tensor, "
-                f"but got {type(net_input)}.")
-        is_param_set_input = isinstance(set_input, Parameter)
-        is_param_net_input = isinstance(net_input, Parameter)
-        if (is_param_set_input and not is_param_net_input) or (is_param_net_input and not is_param_set_input):
-            raise TypeError(
-                f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
-                f"as network's input, but got 'set_inputs': {type(set_input)} and network's input: {type(net_input)}.")
-        if set_input.dtype != net_input.dtype:
-            raise TypeError(
-                f"For 'set_inputs' and tuple(list) in 'set_inputs',the dtype of {index + 1}th input must be the same "
-                f"as network's input, but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
-        if -2 not in set_input.shape:
-            if net_input.dim() != 0 and set_input.dim() != net_input.dim():
-                raise ValueError(
-                    f"For 'set_inputs' and tuple(list) in 'set_inputs',the dims of {index + 1}th input must be the "
-                    f"same as network's input, but got 'set_inputs': {set_input.dim()} and network's input: "
-                    f"{net_input.dim()}.")
-            if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
-                raise ValueError(
-                    f"For 'set_inputs' and tuple(list) in 'set_inputs',the shape of {index + 1}th input must be the "
-                    f"same as network's input, but got 'set_inputs': {set_input.shape} and network's input: "
-                    f"{net_input.shape}.")
-
-    def _check_compile_dynamic_shape(self, set_inputs, net_inputs):
-        """
-        Check if graph has been compiled with dynamic shape.
-
-        Args:
-            net_inputs (tuple): Inputs of the Cell object.
-        """
-        set_inputs_len = len(set_inputs)
-        net_inputs_len = len(net_inputs)
-        if set_inputs_len != net_inputs_len:
-            raise ValueError("The length of 'set_inputs' or tuple(list) in 'set_inputs' must be equal to network's "
-                             f"inputs, but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
-        for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
-            if isinstance(set_input, Tensor):
-                self._check_dynamic_tensor(set_input, net_input, index)
-            elif isinstance(set_input, (tuple, list)):
-                if not isinstance(net_input, (tuple, list)):
-                    raise TypeError(
-                        f"The {index + 1}th input type of 'set_inputs' or tuple(list) in 'set_inputs' must be tuple or "
-                        f"list, but got {type(net_input)}.")
-                self._check_compile_dynamic_shape(set_input, net_input)
-            else:
-                if context._get_mode() == context.PYNATIVE_MODE and set_input is None:
-                    continue
-                if net_input != set_input:
-                    raise ValueError(
-                        f"The {index + 1}th input of 'set_inputs' or tuple(list) in 'set_inputs' must be the same with "
-                        f"network's input, but got set_inputs: {set_input} and network's input: {net_input}.")
-
-    def _run_tracefunc(self, *args, **kwargs):
-        """ Run Packed Cell in Pack."""
-        args = self._mixed_precision_cast(args)
-        need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
-        if not PackFunc.current.is_pynative_mode and need_subgraph:
-            expander = PackExpander.get_instance()
-            args = expander.begin_subgraph(self, *args)
-            args = [_convert_tensor(a) for a in args]
-            output = self._run_construct(args, kwargs)
-            ret = expander.end_subgraph(self, output)
-            output = _convert_tensor(ret)
-        else:
-            with _SetMixedPrecision(self):
-                output = self._run_construct(args, kwargs)
-        return output
-
     def _mixed_precision_cast(self, inputs):
         mixed_type = self.get_mixed_precision_type()
         if mixed_type == MixedPrecisionType.NOTSET:
@@ -2624,6 +2579,7 @@ class GraphCell(Cell):
         params_dict = update_func_graph_hyper_params(self.graph, params_init)
         for name, param in params_dict.items():
             self._params[name] = param
+        _cell_graph_executor.inc_graph_cell_count()
 
     def construct(self, *inputs):
         return self.graph(*inputs)
@@ -2633,7 +2589,7 @@ class GraphCell(Cell):
         self._add_attr("graph_load_from_mindir", self.graph)
         if not self.obf_random_seed:
             return self.compile_and_run(*args, **kwargs)
-        append_input = Tensor((numpy.ones((1,
+        append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32))
         return self.compile_and_run(*args, append_input, **kwargs)
 