mindspore 2.2.11__cp39-cp39-win_amd64.whl → 2.3.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +7 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +76 -18
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +258 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +174 -62
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +142 -240
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +51 -24
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/common/__init__.py +15 -4
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +8 -9
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +411 -106
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +6 -8
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +260 -0
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +30 -11
- mindspore/common/recompute.py +262 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +272 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +468 -496
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +1140 -0
- mindspore/communication/management.py +118 -102
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +378 -65
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +163 -83
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +78 -59
- mindspore/dataset/engine/datasets_vision.py +77 -73
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +56 -38
- mindspore/dataset/engine/validators.py +11 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +8 -8
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +3 -3
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +14 -2
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +71 -127
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +339 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/dataset/vision.h +54 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +2 -2
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +76 -58
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +53 -66
- mindspore/mindrecord/tools/cifar10_to_mr.py +48 -63
- mindspore/mindrecord/tools/csv_to_mr.py +7 -17
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +11 -21
- mindspore/mindrecord/tools/tfrecord_to_mr.py +2 -10
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/mint/__init__.py +1137 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +512 -0
- mindspore/mint/nn/functional.py +573 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +185 -0
- mindspore/multiprocessing/__init__.py +72 -0
- mindspore/nn/__init__.py +1 -0
- mindspore/nn/cell.py +213 -257
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
- mindspore/nn/extend/layer/normalization.py +109 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/layer/activation.py +84 -94
- mindspore/nn/layer/basic.py +177 -82
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +103 -45
- mindspore/nn/layer/embedding_service.py +531 -0
- mindspore/nn/layer/embedding_service_layer.py +393 -0
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +52 -66
- mindspore/nn/layer/padding.py +30 -39
- mindspore/nn/layer/pooling.py +18 -9
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +63 -84
- mindspore/nn/optim/ada_grad.py +6 -4
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +7 -4
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +58 -13
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +32 -9
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +6 -6
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +61 -67
- mindspore/numpy/utils.py +3 -0
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +8 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -160
- mindspore/ops/_grad_experimental/grad_comm_ops.py +93 -36
- mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +92 -287
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/{cpu/concat.py → aicpu/generate_eod_mask.py} +16 -17
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +164 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +130 -58
- mindspore/ops/_vmap/vmap_nn_ops.py +249 -115
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +980 -0
- mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +121 -23
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +191 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +53 -0
- mindspore/ops/extend/array_func.py +218 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/extend/nn_func.py +308 -0
- mindspore/ops/function/__init__.py +31 -11
- mindspore/ops/function/array_func.py +848 -1736
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +2 -5
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +27 -20
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +30 -53
- mindspore/ops/function/math_func.py +916 -2791
- mindspore/ops/function/nn_func.py +1445 -889
- mindspore/ops/function/other_func.py +6 -7
- mindspore/ops/function/parameter_func.py +6 -92
- mindspore/ops/function/random_func.py +254 -108
- mindspore/ops/function/reshard_func.py +102 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +11 -18
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +15 -14
- mindspore/ops/functional.py +342 -343
- mindspore/ops/op_info_register.py +16 -43
- mindspore/ops/operations/__init__.py +32 -23
- mindspore/ops/operations/_embedding_cache_ops.py +1 -1
- mindspore/ops/operations/_grad_ops.py +21 -853
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +155 -511
- mindspore/ops/operations/_quant_ops.py +4 -4
- mindspore/ops/operations/_rl_inner_ops.py +3 -3
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +112 -2698
- mindspore/ops/operations/comm_ops.py +801 -118
- mindspore/ops/operations/custom_ops.py +62 -121
- mindspore/ops/operations/debug_ops.py +105 -36
- mindspore/ops/operations/image_ops.py +3 -219
- mindspore/ops/operations/inner_ops.py +54 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
- mindspore/ops/operations/math_ops.py +621 -4654
- mindspore/ops/operations/nn_ops.py +316 -2226
- mindspore/ops/operations/other_ops.py +53 -45
- mindspore/ops/operations/random_ops.py +4 -51
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +8 -8
- mindspore/ops/primitive.py +204 -103
- mindspore/ops/silent_check.py +162 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +250 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_ops.py +1084 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +968 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +354 -0
- mindspore/ops_generate/template.py +239 -0
- mindspore/parallel/__init__.py +7 -4
- mindspore/parallel/_auto_parallel_context.py +155 -6
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +62 -14
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +18 -9
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +10 -10
- mindspore/parallel/_utils.py +161 -6
- mindspore/parallel/algo_parameter_config.py +6 -8
- mindspore/parallel/checkpoint_transform.py +369 -64
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +344 -0
- mindspore/parallel/cluster/process_entity/_utils.py +126 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +128 -17
- mindspore/profiler/__init__.py +3 -2
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +125 -0
- mindspore/profiler/envprofiling.py +2 -2
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +27 -5
- mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
- mindspore/profiler/parser/ascend_hccl_generator.py +31 -280
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +151 -126
- mindspore/profiler/parser/ascend_msprof_generator.py +75 -274
- mindspore/profiler/parser/ascend_op_generator.py +94 -36
- mindspore/profiler/parser/ascend_timeline_generator.py +297 -131
- mindspore/profiler/parser/base_timeline_generator.py +17 -3
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
- mindspore/profiler/parser/framework_parser.py +11 -4
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +8 -2
- mindspore/profiler/parser/minddata_analyzer.py +8 -2
- mindspore/profiler/parser/minddata_parser.py +73 -4
- mindspore/profiler/parser/msadvisor_analyzer.py +5 -3
- mindspore/profiler/parser/msadvisor_parser.py +10 -4
- mindspore/profiler/parser/profiler_info.py +16 -1
- mindspore/profiler/profiling.py +522 -195
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +123 -37
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +46 -30
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -5
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +151 -37
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +238 -0
- mindspore/train/callback/_landscape.py +16 -11
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_mindio_ttp.py +443 -0
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +13 -14
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -21
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +89 -15
- mindspore/train/model.py +298 -56
- mindspore/train/serialization.py +501 -221
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +1 -1
- mindspore/train/summary/summary_record.py +56 -34
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/METADATA +3 -3
- mindspore-2.3.0.dist-info/RECORD +1400 -0
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.11.dist-info/RECORD +0 -1920
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/top_level.txt +0 -0
|
@@ -24,13 +24,13 @@ from mindspore.ops import operations as P
|
|
|
24
24
|
from mindspore.ops.composite import base
|
|
25
25
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
26
26
|
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
|
|
27
|
-
TopTypeof,
|
|
27
|
+
TopTypeof, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
|
|
28
28
|
SelectView, CopyWithSlice
|
|
29
|
+
from mindspore.ops.operations._sequence_ops import TensorToTuple, TensorToScalar, TupleToTensor
|
|
29
30
|
from mindspore.common import dtype as mstype
|
|
30
31
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
31
32
|
from mindspore.common.initializer import Zero
|
|
32
|
-
from mindspore.common import Tensor, CSRTensor, COOTensor
|
|
33
|
-
from mindspore.common import mutable
|
|
33
|
+
from mindspore.common import Tensor, CSRTensor, COOTensor, mutable
|
|
34
34
|
from mindspore import ops
|
|
35
35
|
from mindspore.ops.primitive import _primexpr
|
|
36
36
|
from mindspore import _checkparam as validator
|
|
@@ -137,7 +137,8 @@ def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=N
|
|
|
137
137
|
elif transfer_type == ValueTransferType.kGatherND:
|
|
138
138
|
if isinstance(new_index, list):
|
|
139
139
|
new_index = handle_multi_dim_index_tensor(new_index, arg)
|
|
140
|
-
|
|
140
|
+
new_index = format_index_tensor(new_index, (None, F.shape(data)[:F.shape(new_index)[-1]]))
|
|
141
|
+
data = F.gather_nd(data, new_index)
|
|
141
142
|
elif transfer_type == ValueTransferType.kTensorScatterUpdate:
|
|
142
143
|
if isinstance(new_index, list):
|
|
143
144
|
new_index = handle_multi_dim_index_tensor(new_index, arg)
|
|
@@ -217,8 +218,8 @@ def _tensor_setitem(self, index, value):
|
|
|
217
218
|
return output
|
|
218
219
|
|
|
219
220
|
|
|
220
|
-
tensor_operator_registry
|
|
221
|
-
tensor_operator_registry
|
|
221
|
+
setattr(tensor_operator_registry, "__getitem__", _tensor_getitem)
|
|
222
|
+
setattr(tensor_operator_registry, "__setitem__", _tensor_setitem)
|
|
222
223
|
|
|
223
224
|
|
|
224
225
|
def _tensor_add(self, other):
|
|
@@ -287,15 +288,15 @@ def _tensor_floordiv(self, other):
|
|
|
287
288
|
return F.floordiv(self, other)
|
|
288
289
|
|
|
289
290
|
|
|
290
|
-
tensor_operator_registry
|
|
291
|
-
tensor_operator_registry
|
|
292
|
-
tensor_operator_registry
|
|
293
|
-
tensor_operator_registry
|
|
294
|
-
tensor_operator_registry
|
|
295
|
-
tensor_operator_registry
|
|
296
|
-
tensor_operator_registry
|
|
297
|
-
tensor_operator_registry
|
|
298
|
-
tensor_operator_registry
|
|
291
|
+
setattr(tensor_operator_registry, '__add__', _tensor_add)
|
|
292
|
+
setattr(tensor_operator_registry, '__sub__', _tensor_sub)
|
|
293
|
+
setattr(tensor_operator_registry, '__mul__', _tensor_mul)
|
|
294
|
+
setattr(tensor_operator_registry, '__matmul__', _tensor_matmul)
|
|
295
|
+
setattr(tensor_operator_registry, '__truediv__', _tensor_div)
|
|
296
|
+
setattr(tensor_operator_registry, '__mod__', _tensor_mod)
|
|
297
|
+
setattr(tensor_operator_registry, '__pow__', _tensor_pow)
|
|
298
|
+
setattr(tensor_operator_registry, '__rpow__', _tensor_rpow)
|
|
299
|
+
setattr(tensor_operator_registry, '__floordiv__', _tensor_floordiv)
|
|
299
300
|
|
|
300
301
|
|
|
301
302
|
def _scalar_to_tensor(input_x):
|
|
@@ -317,24 +318,25 @@ def tensor_item(data, *args):
|
|
|
317
318
|
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
318
319
|
if data.ndim == 0:
|
|
319
320
|
_check_scalar_tensor_args(args)
|
|
320
|
-
return
|
|
321
|
+
return TensorToScalar()(data)
|
|
321
322
|
if len(args) == 1 and isinstance(args[0], tuple):
|
|
322
323
|
args = args[0]
|
|
323
324
|
|
|
324
325
|
args_types = hyper_map(F.typeof, args)
|
|
325
326
|
if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
|
|
326
327
|
if data.shape == (1,):
|
|
327
|
-
return
|
|
328
|
+
return TensorToScalar()(data[0])
|
|
328
329
|
const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
|
|
329
330
|
|
|
330
331
|
if not const_utils.judge_indexes_types(args_types, mstype.int64):
|
|
331
332
|
const_utils.raise_type_error("The index object cannot be interpreted as an integer")
|
|
332
333
|
|
|
333
334
|
if len(args) == data.ndim:
|
|
334
|
-
return
|
|
335
|
+
return tensor_index_by_tuple(data, args)
|
|
335
336
|
if len(args) > 1:
|
|
336
337
|
const_utils.raise_value_error("Incorrect number of indices for array")
|
|
337
|
-
|
|
338
|
+
output = _tensor_index_by_integer(F.reshape(data, (-1,)), args[0])
|
|
339
|
+
return TensorToScalar()(output)
|
|
338
340
|
|
|
339
341
|
|
|
340
342
|
def tensor_itemset(data, *args):
|
|
@@ -354,8 +356,8 @@ def tensor_itemset(data, *args):
|
|
|
354
356
|
return tensor_itemset_with_number(data, args[0])
|
|
355
357
|
|
|
356
358
|
|
|
357
|
-
tensor_operator_registry
|
|
358
|
-
tensor_operator_registry
|
|
359
|
+
setattr(tensor_operator_registry, "item", tensor_item)
|
|
360
|
+
setattr(tensor_operator_registry, "itemset", tensor_itemset)
|
|
359
361
|
|
|
360
362
|
|
|
361
363
|
def tensor_itemset_with_number(data, number_value):
|
|
@@ -521,24 +523,45 @@ def _expand_data_dims(data, tuple_index):
|
|
|
521
523
|
return data, tuple_index_new
|
|
522
524
|
|
|
523
525
|
|
|
524
|
-
def
|
|
525
|
-
"""convert
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
526
|
+
def _convert_list_index_to_tensor(list_index):
|
|
527
|
+
"""convert list to tensor"""
|
|
528
|
+
has_bool = False
|
|
529
|
+
has_int = False
|
|
530
|
+
has_no_bool_int = False
|
|
531
|
+
for idx in list_index:
|
|
532
|
+
if isinstance(idx, bool):
|
|
533
|
+
has_bool = True
|
|
534
|
+
elif isinstance(idx, int):
|
|
535
|
+
has_int = True
|
|
536
|
+
else:
|
|
537
|
+
has_no_bool_int = True
|
|
538
|
+
|
|
539
|
+
all_bool = has_bool and not has_int and not has_no_bool_int
|
|
540
|
+
all_int = has_int and not has_bool and not has_no_bool_int
|
|
541
|
+
all_bool_or_int = not has_no_bool_int
|
|
542
|
+
|
|
543
|
+
if all_int:
|
|
544
|
+
index_tensor = TupleToTensor()(tuple(list_index), mstype.int64)
|
|
545
|
+
return index_tensor
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
if all_bool:
|
|
549
|
+
index_tensor = TupleToTensor()(tuple(list_index), mstype.bool_)
|
|
550
|
+
return index_tensor
|
|
551
|
+
|
|
552
|
+
# convert bool to int if index is mixture of (bool, int)
|
|
553
|
+
if all_bool_or_int:
|
|
554
|
+
new_index = []
|
|
555
|
+
for idx in list_index:
|
|
556
|
+
if isinstance(idx, bool):
|
|
557
|
+
new_idx = int(idx)
|
|
558
|
+
new_index.append(new_idx)
|
|
559
|
+
else:
|
|
560
|
+
new_index.append(idx)
|
|
561
|
+
index_tensor = TupleToTensor()(tuple(new_index), mstype.int64)
|
|
562
|
+
return index_tensor
|
|
563
|
+
|
|
564
|
+
return None
|
|
542
565
|
|
|
543
566
|
|
|
544
567
|
class _TensorIndexGetitem(base.TensorIndexGetitem_):
|
|
@@ -564,26 +587,6 @@ def tensor_index_by_slice(data, slice_index):
|
|
|
564
587
|
return _tensor_index_getitem(data, slice_index)
|
|
565
588
|
|
|
566
589
|
|
|
567
|
-
def get_stride_info_from_slice(data, slice_index):
|
|
568
|
-
"""get the stride info from slice index"""
|
|
569
|
-
data_shape = F.dyn_shape(data)
|
|
570
|
-
begin_strides, end_strides, step_strides = [], [], []
|
|
571
|
-
start, stop, step = get_slice_stride(slice_index, data_shape[0])
|
|
572
|
-
if start.ndim > 0:
|
|
573
|
-
start = start.item()
|
|
574
|
-
if stop.ndim > 0:
|
|
575
|
-
stop = stop.item()
|
|
576
|
-
if step.ndim > 0:
|
|
577
|
-
step = step.item()
|
|
578
|
-
begin_strides.append(start)
|
|
579
|
-
end_strides.append(stop)
|
|
580
|
-
step_strides.append(step)
|
|
581
|
-
begin_tensor = stack(begin_strides)
|
|
582
|
-
end_tensor = stack(end_strides)
|
|
583
|
-
step_tensor = stack(step_strides)
|
|
584
|
-
return begin_tensor, end_tensor, step_tensor
|
|
585
|
-
|
|
586
|
-
|
|
587
590
|
def tensor_index_by_number(data, number_index):
|
|
588
591
|
"""Tensor getitem by a Number which may be integer/float/bool value"""
|
|
589
592
|
if isinstance(number_index, bool):
|
|
@@ -607,31 +610,18 @@ def _tensor_index_by_bool(data, bool_value):
|
|
|
607
610
|
return output
|
|
608
611
|
|
|
609
612
|
|
|
610
|
-
def get_stride_info_from_integer(
|
|
613
|
+
def get_stride_info_from_integer(int_index):
|
|
611
614
|
"""Convert integer to slice"""
|
|
612
|
-
begin_strides =
|
|
613
|
-
end_strides =
|
|
614
|
-
step_strides =
|
|
615
|
-
|
|
616
|
-
end_tensor = stack(end_strides)
|
|
617
|
-
step_tensor = stack(step_strides)
|
|
618
|
-
return begin_tensor, end_tensor, step_tensor
|
|
615
|
+
begin_strides = (int_index,)
|
|
616
|
+
end_strides = (int_index + 1,)
|
|
617
|
+
step_strides = (1,)
|
|
618
|
+
return begin_strides, end_strides, step_strides
|
|
619
619
|
|
|
620
620
|
|
|
621
621
|
def _tensor_index_by_integer(data, int_index):
|
|
622
622
|
"""Tensor getitem by a single integer number"""
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
tensor_index = _scalar_to_tensor(int_index)
|
|
626
|
-
begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
|
|
627
|
-
else:
|
|
628
|
-
if not data_shape:
|
|
629
|
-
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
630
|
-
if data.ndim < 1 or data.ndim > 8:
|
|
631
|
-
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
632
|
-
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
|
633
|
-
begin_strides, end_strides, step_strides = \
|
|
634
|
-
const_utils.get_stride_info_from_integer(data_shape, transformed_number)
|
|
623
|
+
begin_strides, end_strides, step_strides = get_stride_info_from_integer(int_index)
|
|
624
|
+
|
|
635
625
|
shrink_axis_mask = 1
|
|
636
626
|
begin_mask = 0
|
|
637
627
|
end_mask = 0
|
|
@@ -664,6 +654,7 @@ def tensor_index_by_tensor(data, tensor_index):
|
|
|
664
654
|
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
665
655
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
666
656
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
|
|
657
|
+
tensor_index = F.select(tensor_index < 0, tensor_index + F.shape(data)[0], tensor_index)
|
|
667
658
|
return F.gather(data, tensor_index, 0)
|
|
668
659
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
|
|
669
660
|
return tensor_index_by_bool_tensor(data, tensor_index)
|
|
@@ -676,27 +667,23 @@ def tensor_index_by_tensor(data, tensor_index):
|
|
|
676
667
|
def tensor_index_by_list(data, list_index):
|
|
677
668
|
"""Tensor getitem by list of int and bool"""
|
|
678
669
|
min_data_dim, max_data_dim = 1, 8
|
|
679
|
-
|
|
670
|
+
if F.isconstant(data.ndim):
|
|
671
|
+
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
680
672
|
|
|
681
673
|
data_shape = F.shape(data)
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
list_index, data_shape[0])
|
|
696
|
-
if tensor_index is False:
|
|
697
|
-
const_utils.raise_index_error(
|
|
698
|
-
"When tensor is indexed by list, the list can't be empty.")
|
|
699
|
-
return F.gather(data, tensor_index, 0)
|
|
674
|
+
if F.isconstant(data_shape[0]) and all(isinstance(i, bool) for i in list_index):
|
|
675
|
+
if data_shape[0] != len(list_index):
|
|
676
|
+
raise IndexError(
|
|
677
|
+
f'dimension is {data_shape[0]} but corresponding boolean dimension is {len(list_index)}')
|
|
678
|
+
tensor_index = Tensor(list_index).nonzero()
|
|
679
|
+
return F.gather_nd(data, tensor_index)
|
|
680
|
+
|
|
681
|
+
if not list_index:
|
|
682
|
+
const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
|
|
683
|
+
|
|
684
|
+
index_tensor = _convert_list_index_to_tensor(list_index)
|
|
685
|
+
if index_tensor is not None:
|
|
686
|
+
return tensor_index_by_tensor(data, index_tensor)
|
|
700
687
|
|
|
701
688
|
tuple_index_new = ()
|
|
702
689
|
for index in list_index:
|
|
@@ -704,16 +691,6 @@ def tensor_index_by_list(data, list_index):
|
|
|
704
691
|
return tensor_index_by_tuple(data, tuple_index_new)
|
|
705
692
|
|
|
706
693
|
|
|
707
|
-
def convert_tupleslice_to_tensor(tuple_index):
|
|
708
|
-
"""convert mutable scalar in slice to tensor"""
|
|
709
|
-
new_tuple_index = []
|
|
710
|
-
for item in tuple_index:
|
|
711
|
-
if isinstance(item, slice):
|
|
712
|
-
item = convert_variable_to_tensor_slice(item)
|
|
713
|
-
new_tuple_index.append(item)
|
|
714
|
-
return tuple(new_tuple_index)
|
|
715
|
-
|
|
716
|
-
|
|
717
694
|
def judge_tuple_index_dim_check_error(index_dim, data_dim):
|
|
718
695
|
"""raise IndexError when tuple_index's dim is invalid"""
|
|
719
696
|
if index_dim > data_dim:
|
|
@@ -721,29 +698,6 @@ def judge_tuple_index_dim_check_error(index_dim, data_dim):
|
|
|
721
698
|
f"dim of index:{index_dim}, dim of data:{data_dim}")
|
|
722
699
|
|
|
723
700
|
|
|
724
|
-
class _HandleEmptySlice(base.HandleEmptySlice_):
|
|
725
|
-
"""
|
|
726
|
-
Getting item of Tensor.
|
|
727
|
-
|
|
728
|
-
Args:
|
|
729
|
-
data (Tensor): A tuple to be sliced.
|
|
730
|
-
index: Index of tensor.
|
|
731
|
-
|
|
732
|
-
Returns:
|
|
733
|
-
Type is the same as the element type of data.
|
|
734
|
-
"""
|
|
735
|
-
|
|
736
|
-
def __init__(self, name):
|
|
737
|
-
"""Initialize _HandleEmptySlice."""
|
|
738
|
-
base.HandleEmptySlice_.__init__(self, name)
|
|
739
|
-
|
|
740
|
-
def __call__(self, *args):
|
|
741
|
-
pass
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
_handle_empty_slice = _HandleEmptySlice('handle_zero_tuple_index')
|
|
745
|
-
|
|
746
|
-
|
|
747
701
|
def judge_tuple_index_dim(data, tuple_index):
|
|
748
702
|
"""Judge whether tuple_index's dim is valid"""
|
|
749
703
|
data_dim = data.ndim
|
|
@@ -756,50 +710,20 @@ def judge_tuple_index_dim(data, tuple_index):
|
|
|
756
710
|
judge_tuple_index_dim_check_error(index_dim, data_dim)
|
|
757
711
|
|
|
758
712
|
|
|
759
|
-
def judge_simple_tuple_index(data, tuple_index):
|
|
760
|
-
"""Judge whether tuple_index is simple index, which not rollback to cpu ops."""
|
|
761
|
-
op_name = const_utils.TENSOR_GETITEM
|
|
762
|
-
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
763
|
-
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
764
|
-
return F.isconstant(tuple_index) and contain_type == const_utils.ALL_BASIC \
|
|
765
|
-
and F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(F.rank(data))
|
|
766
|
-
|
|
767
|
-
|
|
768
713
|
def tensor_index_by_tuple(data, tuple_index):
|
|
769
714
|
"""Tensor getitem by tuple of various types with None"""
|
|
770
715
|
if not tuple_index:
|
|
771
716
|
return data
|
|
772
|
-
if judge_simple_tuple_index(data, tuple_index):
|
|
773
|
-
tuple_index = convert_tupleslice_to_tensor(tuple_index)
|
|
774
|
-
op_name = const_utils.TENSOR_GETITEM
|
|
775
|
-
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
776
|
-
min_data_dim, max_data_dim = 1, 8
|
|
777
|
-
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
778
|
-
return _tensor_getitem_by_tuple_slice(data, tuple_index)
|
|
779
717
|
|
|
780
718
|
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
781
719
|
judge_tuple_index_dim(data, tuple_index)
|
|
782
720
|
tuple_index, zero_index, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
783
721
|
for non_zero_shape in non_zero_shapes:
|
|
784
|
-
if
|
|
722
|
+
if 0 in non_zero_shape:
|
|
785
723
|
tuple_index = zero_index
|
|
786
724
|
break
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
if 0 in stub_zero_dim_tensor.shape:
|
|
790
|
-
return F.fill(data.dtype, stub_zero_dim_tensor.shape, 0)
|
|
791
|
-
has_tensor_index = False
|
|
792
|
-
for i in tuple_index:
|
|
793
|
-
if isinstance(i, Tensor):
|
|
794
|
-
has_tensor_index = True
|
|
795
|
-
break
|
|
796
|
-
empty_broadcast_data_shape = False
|
|
797
|
-
_broadcast_data_shape = _handle_scalar_tensor_index(data, tuple_index)
|
|
798
|
-
if has_tensor_index and isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
|
|
799
|
-
empty_broadcast_data_shape = True
|
|
800
|
-
if has_tensor_index and isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
|
|
801
|
-
empty_broadcast_data_shape = True
|
|
802
|
-
return _tensor_index_getitem(data, tuple_index, empty_broadcast_data_shape)
|
|
725
|
+
|
|
726
|
+
return _tensor_index_getitem(data, tuple_index)
|
|
803
727
|
|
|
804
728
|
|
|
805
729
|
def get_slice_stride(slice_index, dim_size):
|
|
@@ -809,20 +733,20 @@ def get_slice_stride(slice_index, dim_size):
|
|
|
809
733
|
step = slice_get_item(slice_index, "step")
|
|
810
734
|
|
|
811
735
|
if start is None:
|
|
812
|
-
start =
|
|
736
|
+
start = 0
|
|
813
737
|
if stop is None:
|
|
814
738
|
stop = dim_size
|
|
815
739
|
if step is None:
|
|
816
|
-
step =
|
|
740
|
+
step = 1
|
|
817
741
|
|
|
818
|
-
if
|
|
819
|
-
start =
|
|
742
|
+
if isinstance(start, Tensor):
|
|
743
|
+
start = int(start)
|
|
820
744
|
|
|
821
|
-
if
|
|
822
|
-
stop =
|
|
745
|
+
if isinstance(stop, Tensor):
|
|
746
|
+
stop = int(stop)
|
|
823
747
|
|
|
824
|
-
if
|
|
825
|
-
step =
|
|
748
|
+
if isinstance(step, Tensor):
|
|
749
|
+
step = int(step)
|
|
826
750
|
|
|
827
751
|
return start, stop, step
|
|
828
752
|
|
|
@@ -841,190 +765,6 @@ def cal_tuple_slice_mask(data_shape, tuple_index):
|
|
|
841
765
|
return begin_mask, end_mask
|
|
842
766
|
|
|
843
767
|
|
|
844
|
-
def _get_stride_info_from_tuple(data, tuple_index):
|
|
845
|
-
"""get the stride info from tuple"""
|
|
846
|
-
data_shape = F.dyn_shape(data)
|
|
847
|
-
begin_strides, end_strides, step_strides = [], [], []
|
|
848
|
-
tuple_index_len = len(tuple_index)
|
|
849
|
-
data_dim = data.ndim
|
|
850
|
-
shrink_axis, index_count, ellipsis_count = 0, 0, 0
|
|
851
|
-
for item in range(data_dim):
|
|
852
|
-
if item >= tuple_index_len or item >= data_dim:
|
|
853
|
-
break
|
|
854
|
-
index = tuple_index[item]
|
|
855
|
-
dim_size = data_shape[item]
|
|
856
|
-
if isinstance(index, slice):
|
|
857
|
-
start, stop, step = get_slice_stride(index, dim_size)
|
|
858
|
-
begin_strides.append(start)
|
|
859
|
-
end_strides.append(stop)
|
|
860
|
-
step_strides.append(step)
|
|
861
|
-
index_count = index_count + 1
|
|
862
|
-
elif isinstance(index, int):
|
|
863
|
-
int_tensor = _scalar_to_tensor(index)
|
|
864
|
-
begin_strides.append(int_tensor)
|
|
865
|
-
end_strides.append(int_tensor + const_utils.make_tensor(1))
|
|
866
|
-
step_strides.append(const_utils.make_tensor(1))
|
|
867
|
-
shrink_axis = shrink_axis + (2 ** index_count)
|
|
868
|
-
index_count = index_count + 1
|
|
869
|
-
elif index is ...:
|
|
870
|
-
ellipsis_count = ellipsis_count + 1
|
|
871
|
-
if ellipsis_count > 1:
|
|
872
|
-
const_utils.raise_value_error("An index can have only one ellipsis (...)")
|
|
873
|
-
ellipsis_range_size = data_dim - tuple_index_len + 1
|
|
874
|
-
begin_strides.extend([const_utils.make_tensor(0)] * ellipsis_range_size)
|
|
875
|
-
end_strides.extend(
|
|
876
|
-
[shape for shape in data_shape[index_count: index_count + ellipsis_range_size]])
|
|
877
|
-
step_strides.extend([const_utils.make_tensor(1)] * ellipsis_range_size)
|
|
878
|
-
index_count = index_count + ellipsis_range_size
|
|
879
|
-
else:
|
|
880
|
-
exp_msg = const_utils.gen_exception_msg("Not supported index data type, got {}, type is {}", index,
|
|
881
|
-
type(index))
|
|
882
|
-
const_utils.raise_index_error(exp_msg)
|
|
883
|
-
begin_tensor = stack(begin_strides)
|
|
884
|
-
end_tensor = stack(end_strides)
|
|
885
|
-
step_tensor = stack(step_strides)
|
|
886
|
-
strides_v = {
|
|
887
|
-
'begin': begin_tensor,
|
|
888
|
-
'end': end_tensor,
|
|
889
|
-
'step': step_tensor
|
|
890
|
-
}
|
|
891
|
-
return strides_v, shrink_axis
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|
895
|
-
"""Tensor getitem by a tuple of slice"""
|
|
896
|
-
data_shape = F.shape(data)
|
|
897
|
-
is_dynamic = F.is_sequence_value_unknown(data_shape)
|
|
898
|
-
for item in tuple_index:
|
|
899
|
-
if isinstance(item, slice):
|
|
900
|
-
is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
|
|
901
|
-
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
|
902
|
-
or isinstance(slice_get_item(item, "step"), Tensor)
|
|
903
|
-
|
|
904
|
-
strides_v = {}
|
|
905
|
-
shrink_axis_mask = 0
|
|
906
|
-
if not is_dynamic:
|
|
907
|
-
strides_v, shrink_axis_mask = const_utils.get_stride_info_from_tuple(
|
|
908
|
-
data_shape, tuple_index)
|
|
909
|
-
else:
|
|
910
|
-
strides_v, shrink_axis_mask = _get_stride_info_from_tuple(
|
|
911
|
-
data, tuple_index)
|
|
912
|
-
begin_mask, end_mask = cal_tuple_slice_mask(data_shape, tuple_index)
|
|
913
|
-
begin_v = strides_v['begin']
|
|
914
|
-
end_v = strides_v['end']
|
|
915
|
-
step_v = strides_v['step']
|
|
916
|
-
return strided_slice(data, begin_v, end_v, step_v, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
@_primexpr
|
|
920
|
-
def _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
921
|
-
tensor_positions_new):
|
|
922
|
-
""" parse index of bool tensor type """
|
|
923
|
-
indices = index.nonzero()
|
|
924
|
-
if indices.shape[0] == 0:
|
|
925
|
-
return None, tensor_indexes, tensor_positions_new
|
|
926
|
-
indices = F.cast(indices, mstype.int64)
|
|
927
|
-
indices = indices.T
|
|
928
|
-
for sub_index in indices:
|
|
929
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
930
|
-
tuple_index_new += (sub_index,)
|
|
931
|
-
tensor_indexes.append(sub_index)
|
|
932
|
-
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
def _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new, tensor_indexes, tensor_positions_new):
|
|
936
|
-
""" parse index of tensor type """
|
|
937
|
-
if F.dtype(index) in mstype.int_type:
|
|
938
|
-
tensor_index = F.cast(index, mstype.int64)
|
|
939
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
940
|
-
tuple_index_new += (tensor_index,)
|
|
941
|
-
tensor_indexes.append(tensor_index)
|
|
942
|
-
elif F.dtype(index) == mstype.bool_:
|
|
943
|
-
return _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
944
|
-
tensor_positions_new)
|
|
945
|
-
else:
|
|
946
|
-
exp_msg = const_utils.gen_exception_msg(
|
|
947
|
-
"The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
|
|
948
|
-
const_utils.raise_index_error(exp_msg)
|
|
949
|
-
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|
953
|
-
"""Tensor getitem by a tuple of mixed tensor."""
|
|
954
|
-
slice_is_tensor = False
|
|
955
|
-
for item in tuple_index:
|
|
956
|
-
if isinstance(item, slice):
|
|
957
|
-
slice_is_tensor = isinstance(slice_get_item(item, "start"), Tensor) \
|
|
958
|
-
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
|
959
|
-
or isinstance(slice_get_item(item, "step"), Tensor)
|
|
960
|
-
if slice_is_tensor:
|
|
961
|
-
const_utils.raise_index_error("Not supported when slice has tensor")
|
|
962
|
-
|
|
963
|
-
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
964
|
-
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
|
|
965
|
-
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
|
966
|
-
data_shape = F.shape(data)
|
|
967
|
-
tensor_indexes, slice_indexes = [], []
|
|
968
|
-
tuple_index_new, slice_shapes = (), ()
|
|
969
|
-
slice_positions_new, tensor_positions_new = [], []
|
|
970
|
-
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
|
|
971
|
-
if i in int_positions:
|
|
972
|
-
int_index = const_utils.check_range(index, dim_size)
|
|
973
|
-
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
|
|
974
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
975
|
-
tensor_index = _scalar_to_tensor(int_index)
|
|
976
|
-
tensor_index = F.cast(tensor_index, mstype.int64)
|
|
977
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
978
|
-
tuple_index_new += (tensor_index,)
|
|
979
|
-
tensor_indexes.append(tensor_index)
|
|
980
|
-
elif i in sequence_positions:
|
|
981
|
-
tensor_index = const_utils.sequence_to_index(index, dim_size)
|
|
982
|
-
if tensor_index is False:
|
|
983
|
-
const_utils.raise_index_error("The sequence element(tuple/list) in tuple index can't be empty.")
|
|
984
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
985
|
-
tuple_index_new += (tensor_index,)
|
|
986
|
-
tensor_indexes.append(tensor_index)
|
|
987
|
-
elif i in tensor_positions:
|
|
988
|
-
tuple_index_new, tensor_indexes, tensor_positions_new = \
|
|
989
|
-
_tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new,
|
|
990
|
-
tensor_indexes, tensor_positions_new)
|
|
991
|
-
if tuple_index_new is None:
|
|
992
|
-
return Tensor([])
|
|
993
|
-
elif i in slice_positions:
|
|
994
|
-
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
|
|
995
|
-
slice_shapes += (len(slice_ele_list_index),)
|
|
996
|
-
slice_positions_new.append(len(tuple_index_new))
|
|
997
|
-
tuple_index_new += (slice_ele_list_index,)
|
|
998
|
-
slice_indexes.append(slice_ele_list_index)
|
|
999
|
-
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
|
1000
|
-
broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
|
|
1001
|
-
const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions_new, tensor_indexes_shapes,
|
|
1002
|
-
slice_shapes, op_name)
|
|
1003
|
-
|
|
1004
|
-
tuple_index_len = len(tuple_index)
|
|
1005
|
-
if 0 in final_shape + data_shape:
|
|
1006
|
-
if tuple_index_len < len(data_shape):
|
|
1007
|
-
final_shape = final_shape + data_shape[tuple_index_len:]
|
|
1008
|
-
return const_utils.make_tensor([], data.dtype, final_shape)
|
|
1009
|
-
|
|
1010
|
-
final_index_tensors = []
|
|
1011
|
-
slice_cnt = 0
|
|
1012
|
-
for i, index in enumerate(tuple_index_new):
|
|
1013
|
-
if i in tensor_positions_new:
|
|
1014
|
-
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
|
|
1015
|
-
index)
|
|
1016
|
-
final_index_tensors.append(transform_tensor)
|
|
1017
|
-
elif i in slice_positions_new:
|
|
1018
|
-
slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
|
|
1019
|
-
slice_shapes, fancy_position)
|
|
1020
|
-
final_index_tensors.append(slice_index_tensor)
|
|
1021
|
-
slice_cnt += 1
|
|
1022
|
-
|
|
1023
|
-
indices = stack(final_index_tensors)
|
|
1024
|
-
result = F.gather_nd(data, indices)
|
|
1025
|
-
return result
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
768
|
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
|
|
1029
769
|
"""Generate an indices tensor from a tuple of tensor."""
|
|
1030
770
|
indexes_types = hyper_map(F.dtype, tuple_index)
|
|
@@ -1116,8 +856,15 @@ def sequence_to_tensor(value, dtype):
|
|
|
1116
856
|
|
|
1117
857
|
if value_elements_type == const_utils.ALL_TENSOR:
|
|
1118
858
|
value = F.stack(value).astype(dtype)
|
|
1119
|
-
elif value_elements_type == const_utils.NO_TENSOR
|
|
1120
|
-
|
|
859
|
+
elif value_elements_type == const_utils.NO_TENSOR:
|
|
860
|
+
if isinstance(value, list):
|
|
861
|
+
value = tuple(value)
|
|
862
|
+
|
|
863
|
+
if dtype == mstype.float16:
|
|
864
|
+
value = TupleToTensor()(value, mstype.float32)
|
|
865
|
+
value = F.cast(value, dtype)
|
|
866
|
+
else:
|
|
867
|
+
value = TupleToTensor()(value, dtype)
|
|
1121
868
|
else:
|
|
1122
869
|
new_value = ()
|
|
1123
870
|
for ele in value:
|
|
@@ -1138,57 +885,31 @@ def _generate_updates_from_sequence(data, index, value, op_type):
|
|
|
1138
885
|
def _generate_updates_from_tensor(data, index, value, op_type):
|
|
1139
886
|
"""Generate an updates tensor from a tensor."""
|
|
1140
887
|
value = value.astype(data.dtype)
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
|
1145
|
-
updates = ops.broadcast_to(value, updates_shape)
|
|
1146
|
-
return updates
|
|
1147
|
-
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type, False)
|
|
1148
|
-
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
|
|
1149
|
-
if need_broadcast:
|
|
1150
|
-
return _broadcast(updates_shape, value)
|
|
1151
|
-
return value
|
|
888
|
+
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type)
|
|
889
|
+
updates = ops.broadcast_to(value, updates_shape)
|
|
890
|
+
return updates
|
|
1152
891
|
|
|
1153
892
|
|
|
1154
893
|
# Tensor getitem implementations are above this line, setitem implementations below.
|
|
1155
894
|
|
|
1156
|
-
def
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
return tensor_setitem_by_tensor_with_tensor(self, index, value)
|
|
1161
|
-
return tensor_setitem_by_tensor_with_sequence(self, index, value)
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
def tensor_setitem_by_tuple(self, index, value):
|
|
1165
|
-
index = convert_tupleslice_to_tensor(index)
|
|
1166
|
-
if isinstance(value, (int, float, bool)):
|
|
1167
|
-
index = format_tuple_indices(index)
|
|
1168
|
-
return tensor_setitem_by_tuple_with_number(self, index, value)
|
|
1169
|
-
if isinstance(value, Tensor):
|
|
1170
|
-
return tensor_setitem_by_tuple_with_tensor(self, index, value)
|
|
1171
|
-
return tensor_setitem_by_tuple_with_sequence(self, index, value)
|
|
895
|
+
def _tensor_index_transfer(index, broadcast_shape, final_shape, new_shape):
|
|
896
|
+
"""Transform tuple index tensor to the required."""
|
|
897
|
+
if 0 in final_shape:
|
|
898
|
+
return F.fill(index.dtype, final_shape, 0)
|
|
1172
899
|
|
|
900
|
+
if broadcast_shape == ():
|
|
901
|
+
# broadcast_to () is not support on Ascend
|
|
902
|
+
item = index
|
|
903
|
+
else:
|
|
904
|
+
item = F.broadcast_to(index, broadcast_shape)
|
|
905
|
+
item = F.reshape(item, new_shape)
|
|
906
|
+
return F.broadcast_to(item, final_shape)
|
|
1173
907
|
|
|
1174
|
-
def tensor_setitem_by_number(self, index, value):
|
|
1175
|
-
if isinstance(value, (int, float, bool)):
|
|
1176
|
-
return tensor_setitem_by_number_with_number(self, index, value)
|
|
1177
|
-
if isinstance(value, Tensor):
|
|
1178
|
-
return tensor_setitem_by_number_with_tensor(self, index, value)
|
|
1179
|
-
return tensor_setitem_by_number_with_sequence(self, index, value)
|
|
1180
908
|
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
if not all_empty_tensor:
|
|
1186
|
-
x = F.broadcast_to(x, broadcast_shape)
|
|
1187
|
-
x = F.reshape(x, new_shape)
|
|
1188
|
-
x = F.broadcast_to(x, final_shape)
|
|
1189
|
-
return x
|
|
1190
|
-
item = _broadcast(broadcast_shape, x)
|
|
1191
|
-
return _broadcast(final_shape, F.reshape(item, new_shape))
|
|
909
|
+
def reshape_with_check(x, new_shape):
|
|
910
|
+
if isinstance(new_shape, Tensor):
|
|
911
|
+
new_shape = TensorToTuple()(new_shape)
|
|
912
|
+
return F.reshape(x, new_shape)
|
|
1192
913
|
|
|
1193
914
|
|
|
1194
915
|
class _TensorIndexSetitem(base.TensorIndexSetitem_):
|
|
@@ -1218,9 +939,10 @@ def tensor_setitem_by_slice(self, index, value):
|
|
|
1218
939
|
return self
|
|
1219
940
|
value = F.broadcast_to(value, value_shape)
|
|
1220
941
|
if not const_utils.is_ascend() and step == 1:
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
942
|
+
start = (start,)
|
|
943
|
+
stop = (stop,)
|
|
944
|
+
step = (step,)
|
|
945
|
+
return copy_slice(self, value, start, stop, step)
|
|
1224
946
|
return F.tensor_scatter_update(self, indices, value)
|
|
1225
947
|
|
|
1226
948
|
|
|
@@ -1236,14 +958,14 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
|
|
1236
958
|
"""Set a tensor item by an int tensor with a tensor."""
|
|
1237
959
|
if F.rank(index) == 0:
|
|
1238
960
|
index = F.expand_dims(index, -1)
|
|
1239
|
-
|
|
961
|
+
|
|
1240
962
|
data_shape = F.shape(data)
|
|
963
|
+
updates_shape = index.shape + data_shape[1:]
|
|
964
|
+
value = F.cast(value, F.dtype(data))
|
|
965
|
+
updates = ops.broadcast_to(value, updates_shape)
|
|
1241
966
|
first_val = data_shape[0]
|
|
1242
967
|
index = F.select(index < 0, index + first_val, index)
|
|
1243
968
|
index = F.expand_dims(index, -1)
|
|
1244
|
-
if F.rank(index) < 2:
|
|
1245
|
-
index = F.expand_dims(index, 0)
|
|
1246
|
-
updates = F.expand_dims(updates, 0)
|
|
1247
969
|
if is_parameter(data):
|
|
1248
970
|
F.scatter_nd_update(data, index, updates)
|
|
1249
971
|
return data
|
|
@@ -1255,8 +977,7 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
|
|
1255
977
|
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
|
1256
978
|
index = F.broadcast_to(index, data.shape)
|
|
1257
979
|
value = F.cast(value, F.dtype(data))
|
|
1258
|
-
|
|
1259
|
-
value = value.unsqueeze(-1)
|
|
980
|
+
value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
|
|
1260
981
|
value = F.broadcast_to(value, data.shape)
|
|
1261
982
|
result = F.select(index, value, data)
|
|
1262
983
|
return result
|
|
@@ -1269,8 +990,6 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|
|
1269
990
|
if tensor_dtype == const_utils.INT_:
|
|
1270
991
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
|
1271
992
|
|
|
1272
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1273
|
-
return tensor_setitem_by_tuple_with_tensor(data, (index,), value_tensor.astype(data.dtype))
|
|
1274
993
|
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
|
1275
994
|
|
|
1276
995
|
|
|
@@ -1281,33 +1000,8 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
|
|
|
1281
1000
|
|
|
1282
1001
|
def tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
1283
1002
|
"""Assigns the tensor by tensor with tuple value."""
|
|
1284
|
-
index_dtype = F.dtype(index)
|
|
1285
|
-
if index_dtype in (mstype.int32, mstype.int64):
|
|
1286
|
-
return _tensor_setitem_by_tensor_with_sequence(data, index, value)
|
|
1287
|
-
if index_dtype == mstype.bool_:
|
|
1288
|
-
return _tensor_setitem_by_bool_tensor_with_sequence(data, index, value)
|
|
1289
|
-
exp_msg = const_utils.gen_exception_msg("The tensor index must be int or bool type, but got {}.", index_dtype)
|
|
1290
|
-
const_utils.raise_index_error(exp_msg)
|
|
1291
|
-
return None
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
def _tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
1295
|
-
"""Set a tensor item by a tensor with a tuple."""
|
|
1296
|
-
updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
|
1297
|
-
index = F.expand_dims(index, -1)
|
|
1298
|
-
return F.tensor_scatter_update(data, index, updates)
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
|
|
1302
|
-
"""Set a tensor item by a bool tensor with a tuple."""
|
|
1303
1003
|
value = sequence_to_tensor(value, F.dtype(data))
|
|
1304
|
-
return
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
|
1308
|
-
"""Givens a scalar assign to tensor by slice"""
|
|
1309
|
-
value = F.cast(value, F.dtype(data))
|
|
1310
|
-
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
|
1004
|
+
return tensor_setitem_by_tensor_with_tensor(data, index, value)
|
|
1311
1005
|
|
|
1312
1006
|
|
|
1313
1007
|
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|
@@ -1316,78 +1010,14 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|
|
1316
1010
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
|
1317
1011
|
|
|
1318
1012
|
|
|
1319
|
-
def
|
|
1320
|
-
"""
|
|
1321
|
-
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
stop_tensor = stack((stop,))
|
|
1325
|
-
step_tensor = stack((step,))
|
|
1326
|
-
dim0_size = stop_tensor - start_tensor
|
|
1327
|
-
if dim0_size <= 0:
|
|
1328
|
-
return data
|
|
1329
|
-
if dim0_size >= data_shape[0]:
|
|
1330
|
-
dim0_size = data_shape[0:1]
|
|
1331
|
-
value_shape = P.Concat(-1)((dim0_size, data_shape[1:]))
|
|
1332
|
-
value = ops.broadcast_to(value, value_shape)
|
|
1333
|
-
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
|
1337
|
-
"""Assigns a tensor value to the tensor by slice."""
|
|
1338
|
-
result = None
|
|
1339
|
-
check_result = const_utils.check_tensor_setitem_index(input_slice)
|
|
1340
|
-
if check_result:
|
|
1341
|
-
data_shape = F.shape(data)
|
|
1342
|
-
step = const_utils.get_step_from_slice(input_slice)
|
|
1343
|
-
if step == 1 and not const_utils.is_ascend():
|
|
1344
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1345
|
-
return tensor_copy_slice_from_slice(data, input_slice, value)
|
|
1346
|
-
start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
|
|
1347
|
-
dim0_size = stop - start
|
|
1348
|
-
if dim0_size <= 0:
|
|
1349
|
-
return data
|
|
1350
|
-
value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
|
|
1351
|
-
value = _broadcast(value_shape, value)
|
|
1352
|
-
return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
|
|
1353
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1354
|
-
const_utils.raise_unimplemented_error(
|
|
1355
|
-
"Not supported to take the subscript of dynamic shape tensor slice setitem")
|
|
1356
|
-
indices = const_utils.slice2indices(input_slice, data_shape)
|
|
1357
|
-
if indices is False:
|
|
1358
|
-
return data
|
|
1359
|
-
value_shape = const_utils.tuple_slice(F.shape(indices), None, -1)
|
|
1360
|
-
value = _broadcast(value_shape, value)
|
|
1361
|
-
result = F.tensor_scatter_update(data, indices, value.astype(F.dtype(data)))
|
|
1362
|
-
return result
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
def tensor_setitem_by_slice_with_sequence(data, input_slice, value):
|
|
1366
|
-
"""Assigns a list/tuple value to the tensor by slice."""
|
|
1367
|
-
value = _generate_updates_from_sequence(data, input_slice, value, const_utils.SET_ITEM_BY_NON_TENSOR)
|
|
1368
|
-
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
|
1013
|
+
def tensor_setitem_by_list(data, index, value):
|
|
1014
|
+
"""list indices will be converted to tuple or tensor based on its contents."""
|
|
1015
|
+
index_tensor = _convert_list_index_to_tensor(index)
|
|
1016
|
+
if index_tensor is not None:
|
|
1017
|
+
return tensor_setitem_by_tensor_with_tensor(data, index_tensor, value)
|
|
1369
1018
|
|
|
1019
|
+
return tensor_setitem_by_tuple_with_tensor(data, tuple(index), value)
|
|
1370
1020
|
|
|
1371
|
-
def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
1372
|
-
"""using TensorCopySlices by fixed model tuple."""
|
|
1373
|
-
data_shape = F.dyn_shape(data)
|
|
1374
|
-
dim1_start, dim1_stop, _ = get_slice_stride(tuple_index[1], data_shape[1])
|
|
1375
|
-
if dim1_stop - dim1_start <= 0:
|
|
1376
|
-
return data
|
|
1377
|
-
dim0_start = _scalar_to_tensor(tuple_index[0])
|
|
1378
|
-
dim0_stop = dim0_start + const_utils.make_tensor(1)
|
|
1379
|
-
start = (dim0_start, dim1_start)
|
|
1380
|
-
stop = (dim0_stop, dim1_stop)
|
|
1381
|
-
step = (const_utils.make_tensor(1), const_utils.make_tensor(1))
|
|
1382
|
-
start_tensor = stack(start)
|
|
1383
|
-
stop_tensor = stack(stop)
|
|
1384
|
-
step_tensor = stack(step)
|
|
1385
|
-
dim1_size = stack((dim1_stop - dim1_start,))
|
|
1386
|
-
if dim1_size > data_shape[1]:
|
|
1387
|
-
dim1_size = data_shape[1:2]
|
|
1388
|
-
value_shape = P.Concat(-1)((dim1_size, data_shape[2:]))
|
|
1389
|
-
value = ops.broadcast_to(value, value_shape)
|
|
1390
|
-
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
1391
1021
|
|
|
1392
1022
|
|
|
1393
1023
|
class _PreSetitemByTuple(base.PreSetitemByTuple_):
|
|
@@ -1436,50 +1066,28 @@ class _HandleBoolTensor(base.HandleBoolTensor_):
|
|
|
1436
1066
|
_handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
|
|
1437
1067
|
|
|
1438
1068
|
|
|
1439
|
-
class _HandleScalarTensorIndex(base.HandleScalarTensorIndex_):
|
|
1440
|
-
"""
|
|
1441
|
-
Getting item of Tensor.
|
|
1442
|
-
|
|
1443
|
-
Args:
|
|
1444
|
-
data (Tensor): A tuple to be sliced.
|
|
1445
|
-
index: Index of tensor.
|
|
1446
|
-
|
|
1447
|
-
Returns:
|
|
1448
|
-
Type is the same as the element type of data.
|
|
1449
|
-
"""
|
|
1450
|
-
|
|
1451
|
-
def __init__(self, name):
|
|
1452
|
-
"""Initialize _HandleBoolTensor."""
|
|
1453
|
-
base.HandleScalarTensorIndex_.__init__(self, name)
|
|
1454
|
-
|
|
1455
|
-
def __call__(self, *args):
|
|
1456
|
-
pass
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
_handle_scalar_tensor_index = _HandleScalarTensorIndex('handle_scalar_tensor_index')
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
1069
|
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
1463
1070
|
"""Assigns the tensor by tuple with tensor value."""
|
|
1464
1071
|
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1465
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1466
|
-
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1467
1072
|
dim1_start, dim1_stop, _ = const_utils.normalize_slice(
|
|
1468
1073
|
tuple_index[1], data.shape[1])
|
|
1074
|
+
if isinstance(dim1_start, Tensor):
|
|
1075
|
+
dim1_start = int(dim1_start)
|
|
1076
|
+
if isinstance(dim1_stop, Tensor):
|
|
1077
|
+
dim1_stop = int(dim1_stop)
|
|
1469
1078
|
if dim1_stop - dim1_start <= 0:
|
|
1470
1079
|
return data
|
|
1471
1080
|
dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
|
|
1472
1081
|
start = (dim0_start, dim1_start)
|
|
1473
1082
|
stop = (dim0_start + 1, dim1_stop)
|
|
1474
1083
|
step = (1, 1)
|
|
1475
|
-
value_shape = (dim1_stop - dim1_start,) +
|
|
1476
|
-
|
|
1477
|
-
value = _broadcast(value_shape, value)
|
|
1084
|
+
value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
|
|
1085
|
+
value = F.broadcast_to(value, value_shape)
|
|
1478
1086
|
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1479
1087
|
tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
1480
1088
|
|
|
1481
1089
|
for non_zero_shape in non_zero_shapes:
|
|
1482
|
-
if
|
|
1090
|
+
if 0 in non_zero_shape:
|
|
1483
1091
|
return data
|
|
1484
1092
|
value = value.astype(data.dtype)
|
|
1485
1093
|
special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
|
|
@@ -1512,17 +1120,19 @@ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1512
1120
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
1513
1121
|
|
|
1514
1122
|
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1515
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1516
|
-
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1517
1123
|
dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
|
|
1124
|
+
if isinstance(dim1_start, Tensor):
|
|
1125
|
+
dim1_start = int(dim1_start)
|
|
1126
|
+
if isinstance(dim1_stop, Tensor):
|
|
1127
|
+
dim1_stop = int(dim1_stop)
|
|
1518
1128
|
if dim1_stop - dim1_start <= 0:
|
|
1519
1129
|
return data
|
|
1520
1130
|
dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
|
|
1521
1131
|
start = (dim0_start, dim1_start)
|
|
1522
1132
|
stop = (dim0_start + 1, dim1_stop)
|
|
1523
1133
|
step = (1, 1)
|
|
1524
|
-
value_shape = (dim1_stop - dim1_start,) +
|
|
1525
|
-
value =
|
|
1134
|
+
value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
|
|
1135
|
+
value = F.broadcast_to(value, value_shape)
|
|
1526
1136
|
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1527
1137
|
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
|
|
1528
1138
|
|
|
@@ -1545,49 +1155,45 @@ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1545
1155
|
|
|
1546
1156
|
|
|
1547
1157
|
def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
|
|
1548
|
-
value =
|
|
1158
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1549
1159
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
|
1550
1160
|
|
|
1551
1161
|
|
|
1552
1162
|
def tensor_setitem_by_number_with_number(data, index, value):
|
|
1553
1163
|
"""Assigns the tensor by number with number value."""
|
|
1554
|
-
|
|
1555
|
-
|
|
1164
|
+
data_shape = F.shape(data)
|
|
1165
|
+
dim_size = data_shape[0]
|
|
1166
|
+
if index < 0:
|
|
1167
|
+
index += dim_size
|
|
1168
|
+
if index < -dim_size or index >= dim_size:
|
|
1169
|
+
raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
|
|
1170
|
+
index = F.cast(index, mstype.int64)
|
|
1171
|
+
index = F.reshape(index, (1, 1))
|
|
1172
|
+
|
|
1173
|
+
updates = F.cast(value, data.dtype)
|
|
1174
|
+
updates_shape = (1,) + data_shape[1:]
|
|
1175
|
+
updates = ops.broadcast_to(updates, updates_shape)
|
|
1176
|
+
|
|
1177
|
+
if is_parameter(data):
|
|
1178
|
+
F.scatter_nd_update(data, index, updates)
|
|
1179
|
+
return data
|
|
1180
|
+
return F.tensor_scatter_update(data, index, updates)
|
|
1556
1181
|
|
|
1557
1182
|
|
|
1558
1183
|
def tensor_setitem_by_number_with_sequence(data, index, value):
|
|
1559
1184
|
"""Assigns a list/tuple value to the tensor by slice."""
|
|
1560
|
-
value =
|
|
1185
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1561
1186
|
return tensor_setitem_by_number_with_tensor(data, index, value)
|
|
1562
1187
|
|
|
1563
1188
|
|
|
1564
1189
|
def tensor_setitem_by_number_with_tensor(data, index, value):
|
|
1565
|
-
|
|
1566
|
-
data_shape = F.shape(data)
|
|
1567
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1568
|
-
index = _scalar_to_tensor(index)
|
|
1569
|
-
index = F.expand_dims(index, -1)
|
|
1570
|
-
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value)
|
|
1571
|
-
|
|
1572
|
-
dim_size = data_shape[0]
|
|
1573
|
-
if index < -dim_size or index >= dim_size:
|
|
1574
|
-
raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
|
|
1575
|
-
index = const_utils.int_to_index(index, data_shape)
|
|
1576
|
-
value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
|
|
1577
|
-
value = _broadcast(value_shape, value.astype(F.dtype(data)))
|
|
1578
|
-
if is_parameter(data):
|
|
1579
|
-
F.scatter_nd_update(data, index, value)
|
|
1580
|
-
return data
|
|
1581
|
-
return F.tensor_scatter_update(data, index, value)
|
|
1190
|
+
return tensor_setitem_by_number_with_number(data, index, value)
|
|
1582
1191
|
|
|
1583
1192
|
|
|
1584
1193
|
def tensor_setitem_by_ellipsis_with_number(data, value):
|
|
1585
1194
|
"""Assigns the tensor by ellipsis with number value."""
|
|
1586
1195
|
data_shape = F.shape(data)
|
|
1587
1196
|
data_dtype = F.dtype(data)
|
|
1588
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1589
|
-
value = F.cast(value, F.dtype(data))
|
|
1590
|
-
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1591
1197
|
return F.fill(data_dtype, data_shape, value)
|
|
1592
1198
|
|
|
1593
1199
|
|
|
@@ -1597,21 +1203,20 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
|
|
|
1597
1203
|
data_dtype = F.dtype(data)
|
|
1598
1204
|
value = value.astype(data_dtype)
|
|
1599
1205
|
|
|
1600
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1601
|
-
data_shape = F.dyn_shape(data)
|
|
1602
|
-
data = ops.broadcast_to(value, data_shape)
|
|
1603
|
-
return data
|
|
1604
1206
|
value_shape = F.shape(value)
|
|
1605
|
-
|
|
1207
|
+
|
|
1208
|
+
if len(value_shape) > len(data_shape):
|
|
1209
|
+
source_shape = data_shape
|
|
1210
|
+
else:
|
|
1211
|
+
source_shape = value_shape
|
|
1606
1212
|
value = F.reshape(value, source_shape)
|
|
1607
|
-
|
|
1608
|
-
data = F.cast(value, data_dtype)
|
|
1213
|
+
data = F.broadcast_to(value, data_shape)
|
|
1609
1214
|
return data
|
|
1610
1215
|
|
|
1611
1216
|
|
|
1612
1217
|
def tensor_setitem_by_ellipsis_with_sequence(data, value):
|
|
1613
1218
|
"""Assigns a list/tuple value to the tensor by ellipsis."""
|
|
1614
|
-
value =
|
|
1219
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1615
1220
|
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1616
1221
|
|
|
1617
1222
|
|
|
@@ -1622,23 +1227,18 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1622
1227
|
if not index:
|
|
1623
1228
|
data_shape = (0,) + data_shape
|
|
1624
1229
|
if isinstance(value, (list, tuple)):
|
|
1625
|
-
value =
|
|
1626
|
-
|
|
1627
|
-
value =
|
|
1628
|
-
|
|
1629
|
-
value = const_utils.make_tensor(value, mstype.float32)
|
|
1630
|
-
|
|
1631
|
-
if F.is_sequence_value_unknown(data_shape) and index:
|
|
1632
|
-
data_shape = F.dyn_shape(data)
|
|
1633
|
-
value = value.astype(data_dtype)
|
|
1634
|
-
data = ops.broadcast_to(value, data_shape)
|
|
1635
|
-
return data
|
|
1636
|
-
value_shape = F.shape(value)
|
|
1637
|
-
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
|
1230
|
+
value = sequence_to_tensor(value, data_dtype)
|
|
1231
|
+
else:
|
|
1232
|
+
value = F.cast(value, data_dtype)
|
|
1233
|
+
|
|
1638
1234
|
if index:
|
|
1235
|
+
value_shape = F.shape(value)
|
|
1236
|
+
if len(value_shape) > len(data_shape):
|
|
1237
|
+
source_shape = data_shape
|
|
1238
|
+
else:
|
|
1239
|
+
source_shape = value_shape
|
|
1639
1240
|
value = F.reshape(value, source_shape)
|
|
1640
|
-
|
|
1641
|
-
data = F.cast(value, data_dtype)
|
|
1241
|
+
data = F.broadcast_to(value, data_shape)
|
|
1642
1242
|
return data
|
|
1643
1243
|
|
|
1644
1244
|
|
|
@@ -1651,33 +1251,6 @@ def tensor_in_sequence(x, y):
|
|
|
1651
1251
|
return result
|
|
1652
1252
|
|
|
1653
1253
|
|
|
1654
|
-
def format_list_indices(list_indices, length):
|
|
1655
|
-
"""Convert list indices to tensor or tuple indices based on its contents."""
|
|
1656
|
-
indices_types = hyper_map(F.typeof, list_indices)
|
|
1657
|
-
# If eyery element in list is bool, it's treated as 1-D bool tensor.
|
|
1658
|
-
# If every element in list is int(not all bool), it's treated as int tensor.
|
|
1659
|
-
if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
|
|
1660
|
-
if not F.isconstant(length):
|
|
1661
|
-
return const_utils.sequence_to_index(list_indices, None)
|
|
1662
|
-
return const_utils.sequence_to_index(list_indices, length)
|
|
1663
|
-
# If list contains other types(.../list/tuple/None), it's treated as a tuple
|
|
1664
|
-
return const_utils.deep_tuple(list_indices)
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
def format_tuple_indices(tuple_indices):
|
|
1668
|
-
"""
|
|
1669
|
-
Format tuple indices by unpacking high-dimension tuple and removing expand
|
|
1670
|
-
dimension signs(Bool and None).
|
|
1671
|
-
"""
|
|
1672
|
-
res = ()
|
|
1673
|
-
for i in tuple_indices:
|
|
1674
|
-
if isinstance(i, (list, tuple)):
|
|
1675
|
-
res += (const_utils.unpack(i),)
|
|
1676
|
-
else:
|
|
1677
|
-
res += (i,)
|
|
1678
|
-
return res
|
|
1679
|
-
|
|
1680
|
-
|
|
1681
1254
|
@_primexpr
|
|
1682
1255
|
def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
|
|
1683
1256
|
""" Parse bool tensor index """
|
|
@@ -1830,7 +1403,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
|
|
|
1830
1403
|
return reduce_fn(a, axes).astype(dtype)
|
|
1831
1404
|
|
|
1832
1405
|
|
|
1833
|
-
tensor_operator_registry
|
|
1406
|
+
setattr(tensor_operator_registry, "reduce", reduce_)
|
|
1834
1407
|
|
|
1835
1408
|
|
|
1836
1409
|
def check_indices(dims, indices, mode, allow_negative_index=True):
|
|
@@ -1857,7 +1430,7 @@ def check_indices(dims, indices, mode, allow_negative_index=True):
|
|
|
1857
1430
|
return clipped
|
|
1858
1431
|
|
|
1859
1432
|
|
|
1860
|
-
tensor_operator_registry
|
|
1433
|
+
setattr(tensor_operator_registry, 'check_indices', check_indices)
|
|
1861
1434
|
|
|
1862
1435
|
|
|
1863
1436
|
def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
|