mindspore 2.2.14__cp39-cp39-manylinux1_x86_64.whl → 2.3.0rc1__cp39-cp39-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -4
- mindspore/_akg/akg/composite/build_module.py +155 -11
- mindspore/_akg/akg/config/repository.json +38 -0
- mindspore/_akg/akg/ms/info_version_adapt.py +29 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -1
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +2 -1
- mindspore/_akg/akg/utils/composite_op_helper.py +4 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +2 -2
- mindspore/_akg/akg/utils/gen_random.py +14 -8
- mindspore/_akg/akg/utils/op_dsl.py +11 -0
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +5 -5
- mindspore/_c_dataengine.cpython-39-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-39-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +58 -0
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +229 -0
- mindspore/_extends/parse/parser.py +155 -59
- mindspore/_extends/parse/resources.py +40 -7
- mindspore/_extends/parse/standard_method.py +124 -204
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_mindspore_offline_debug.cpython-39-x86_64-linux-gnu.so +0 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +24 -18
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +3 -1
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_stub_tensor.py +6 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +91 -48
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +5 -4
- mindspore/common/dump.py +5 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +20 -11
- mindspore/common/lazy_inline.py +58 -17
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/parameter.py +19 -4
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +251 -18
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +321 -433
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/management.py +53 -38
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +167 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +72 -38
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +7 -7
- mindspore/dataset/engine/datasets_vision.py +75 -71
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +5 -5
- mindspore/dataset/vision/transforms.py +2264 -514
- mindspore/dataset/vision/utils.py +40 -9
- mindspore/dataset/vision/validators.py +7 -1
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +60 -119
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +34 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/stream.py +337 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +1 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/mindapi/base/format.h +125 -23
- mindspore/include/mindapi/base/types.h +7 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +2044 -154
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +2044 -33
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/build_tbe_kernel.py +529 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/compiler.py +56 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/custom.py +1109 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/get_file_path.py +36 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/tbe_topi.py +556 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6325 -1767
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_add_custom.h +49 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +52 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.cpp +192 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.cpp +274 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libcust_opmaster_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +39 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/x86_64/libcust_opsproto_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/{libmindspore_ascend.so.1 → libmindspore_ascend.so.2} +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +74 -56
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/multiprocessing/__init__.py +68 -0
- mindspore/nn/cell.py +86 -133
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +79 -90
- mindspore/nn/layer/basic.py +4 -80
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +71 -71
- mindspore/nn/layer/embedding.py +105 -44
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/normalization.py +46 -38
- mindspore/nn/layer/padding.py +26 -39
- mindspore/nn/layer/pooling.py +13 -9
- mindspore/nn/layer/rnn_cells.py +5 -15
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +43 -64
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +6 -5
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +33 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +7 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +1 -5
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +54 -60
- mindspore/numpy/utils.py +3 -0
- mindspore/ops/__init__.py +5 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +16 -22
- mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +137 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +102 -56
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +74 -49
- mindspore/ops/_vmap/vmap_nn_ops.py +164 -89
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +133 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +248 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +147 -0
- mindspore/ops/auto_generate/gen_extend_func.py +130 -0
- mindspore/ops/auto_generate/gen_ops_def.py +4786 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +8335 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +77 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +118 -17
- mindspore/ops/composite/math_ops.py +9 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +166 -601
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +15 -133
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +46 -0
- mindspore/ops/extend/array_func.py +152 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/{_op_impl/tbe/atomic_addr_clean.py → extend/nn_func.py} +5 -15
- mindspore/ops/function/__init__.py +19 -11
- mindspore/ops/function/array_func.py +251 -1440
- mindspore/ops/function/clip_func.py +12 -13
- mindspore/ops/function/debug_func.py +1 -4
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +24 -17
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +451 -2360
- mindspore/ops/function/nn_func.py +459 -780
- mindspore/ops/function/other_func.py +4 -5
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +24 -80
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +14 -14
- mindspore/ops/functional.py +56 -62
- mindspore/ops/op_info_register.py +22 -19
- mindspore/ops/operations/__init__.py +19 -19
- mindspore/ops/operations/_grad_ops.py +20 -723
- mindspore/ops/operations/_inner_ops.py +178 -286
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +4 -34
- mindspore/ops/operations/array_ops.py +99 -2491
- mindspore/ops/operations/comm_ops.py +38 -46
- mindspore/ops/operations/custom_ops.py +8 -8
- mindspore/ops/operations/debug_ops.py +100 -31
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +3 -38
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/{rewrite/ast_transformers → ops/operations/manually_defined}/__init__.py +11 -4
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +1391 -0
- mindspore/ops/operations/math_ops.py +703 -4601
- mindspore/ops/operations/nn_ops.py +374 -1748
- mindspore/ops/operations/other_ops.py +50 -42
- mindspore/ops/operations/random_ops.py +3 -52
- mindspore/ops/primitive.py +196 -96
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +248 -0
- mindspore/ops_generate/arg_handler.py +147 -0
- mindspore/ops_generate/gen_aclnn_implement.py +266 -0
- mindspore/ops_generate/gen_ops.py +1062 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +129 -0
- mindspore/ops_generate/gen_pyboost_func.py +932 -0
- mindspore/ops_generate/gen_utils.py +188 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +364 -0
- mindspore/ops_generate/template.py +238 -0
- mindspore/parallel/__init__.py +5 -4
- mindspore/parallel/_auto_parallel_context.py +21 -76
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +30 -46
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +19 -7
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +1 -1
- mindspore/parallel/_utils.py +131 -6
- mindspore/parallel/algo_parameter_config.py +6 -6
- mindspore/parallel/checkpoint_transform.py +180 -196
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +345 -0
- mindspore/parallel/cluster/process_entity/_utils.py +116 -0
- mindspore/parallel/cluster/run.py +139 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +99 -2
- mindspore/profiler/common/util.py +20 -0
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +66 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +77 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +146 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +108 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +80 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +52 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +59 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +25 -277
- mindspore/profiler/parser/ascend_msprof_exporter.py +112 -132
- mindspore/profiler/parser/ascend_msprof_generator.py +68 -285
- mindspore/profiler/parser/ascend_op_generator.py +75 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +293 -135
- mindspore/profiler/parser/base_timeline_generator.py +6 -0
- mindspore/profiler/parser/framework_parser.py +3 -2
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +5 -0
- mindspore/profiler/profiling.py +296 -166
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +121 -35
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +635 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +107 -144
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/scipy/__init__.py +2 -1
- mindspore/scipy/fft.py +133 -0
- mindspore/scipy/linalg.py +140 -55
- mindspore/scipy/ops.py +15 -71
- mindspore/scipy/ops_grad.py +5 -34
- mindspore/scipy/optimize/line_search.py +2 -2
- mindspore/scipy/optimize/minimize.py +1 -1
- mindspore/train/__init__.py +3 -2
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/callback/_backup_and_restore.py +4 -4
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +39 -13
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +14 -8
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +7 -7
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/data_sink.py +1 -1
- mindspore/train/dataset_helper.py +13 -4
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +75 -6
- mindspore/train/model.py +24 -22
- mindspore/train/serialization.py +256 -132
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/METADATA +2 -2
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/RECORD +515 -1061
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/config/super_bar_config.json +0 -544
- mindspore/gen_ops.py +0 -273
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_creator_register.py +0 -37
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -24,13 +24,13 @@ from mindspore.ops import operations as P
|
|
|
24
24
|
from mindspore.ops.composite import base
|
|
25
25
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
26
26
|
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
|
|
27
|
-
TopTypeof,
|
|
27
|
+
TopTypeof, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
|
|
28
28
|
SelectView, CopyWithSlice
|
|
29
|
+
from mindspore.ops.operations._sequence_ops import TensorToTuple, TensorToScalar, TupleToTensor
|
|
29
30
|
from mindspore.common import dtype as mstype
|
|
30
31
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
31
32
|
from mindspore.common.initializer import Zero
|
|
32
|
-
from mindspore.common import Tensor, CSRTensor, COOTensor
|
|
33
|
-
from mindspore.common import mutable
|
|
33
|
+
from mindspore.common import Tensor, CSRTensor, COOTensor, mutable
|
|
34
34
|
from mindspore import ops
|
|
35
35
|
from mindspore.ops.primitive import _primexpr
|
|
36
36
|
from mindspore import _checkparam as validator
|
|
@@ -317,24 +317,25 @@ def tensor_item(data, *args):
|
|
|
317
317
|
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
318
318
|
if data.ndim == 0:
|
|
319
319
|
_check_scalar_tensor_args(args)
|
|
320
|
-
return
|
|
320
|
+
return TensorToScalar()(data)
|
|
321
321
|
if len(args) == 1 and isinstance(args[0], tuple):
|
|
322
322
|
args = args[0]
|
|
323
323
|
|
|
324
324
|
args_types = hyper_map(F.typeof, args)
|
|
325
325
|
if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
|
|
326
326
|
if data.shape == (1,):
|
|
327
|
-
return
|
|
327
|
+
return TensorToScalar()(data[0])
|
|
328
328
|
const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
|
|
329
329
|
|
|
330
330
|
if not const_utils.judge_indexes_types(args_types, mstype.int64):
|
|
331
331
|
const_utils.raise_type_error("The index object cannot be interpreted as an integer")
|
|
332
332
|
|
|
333
333
|
if len(args) == data.ndim:
|
|
334
|
-
return
|
|
334
|
+
return tensor_index_by_tuple(data, args)
|
|
335
335
|
if len(args) > 1:
|
|
336
336
|
const_utils.raise_value_error("Incorrect number of indices for array")
|
|
337
|
-
|
|
337
|
+
output = _tensor_index_by_integer(F.reshape(data, (-1,)), args[0])
|
|
338
|
+
return TensorToScalar()(output)
|
|
338
339
|
|
|
339
340
|
|
|
340
341
|
def tensor_itemset(data, *args):
|
|
@@ -521,24 +522,45 @@ def _expand_data_dims(data, tuple_index):
|
|
|
521
522
|
return data, tuple_index_new
|
|
522
523
|
|
|
523
524
|
|
|
524
|
-
def
|
|
525
|
-
"""convert
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
525
|
+
def _convert_list_index_to_tensor(list_index):
|
|
526
|
+
"""convert list to tensor"""
|
|
527
|
+
has_bool = False
|
|
528
|
+
has_int = False
|
|
529
|
+
has_no_bool_int = False
|
|
530
|
+
for idx in list_index:
|
|
531
|
+
if isinstance(idx, bool):
|
|
532
|
+
has_bool = True
|
|
533
|
+
elif isinstance(idx, int):
|
|
534
|
+
has_int = True
|
|
535
|
+
else:
|
|
536
|
+
has_no_bool_int = True
|
|
537
|
+
|
|
538
|
+
all_bool = has_bool and not has_int and not has_no_bool_int
|
|
539
|
+
all_int = has_int and not has_bool and not has_no_bool_int
|
|
540
|
+
all_bool_or_int = not has_no_bool_int
|
|
541
|
+
|
|
542
|
+
if all_int:
|
|
543
|
+
index_tensor = TupleToTensor()(tuple(list_index), mstype.int64)
|
|
544
|
+
return index_tensor
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
if all_bool:
|
|
548
|
+
index_tensor = TupleToTensor()(tuple(list_index), mstype.bool_)
|
|
549
|
+
return index_tensor
|
|
550
|
+
|
|
551
|
+
# convert bool to int if index is mixture of (bool, int)
|
|
552
|
+
if all_bool_or_int:
|
|
553
|
+
new_index = []
|
|
554
|
+
for idx in list_index:
|
|
555
|
+
if isinstance(idx, bool):
|
|
556
|
+
new_idx = int(idx)
|
|
557
|
+
new_index.append(new_idx)
|
|
558
|
+
else:
|
|
559
|
+
new_index.append(idx)
|
|
560
|
+
index_tensor = TupleToTensor()(tuple(new_index), mstype.int64)
|
|
561
|
+
return index_tensor
|
|
562
|
+
|
|
563
|
+
return None
|
|
542
564
|
|
|
543
565
|
|
|
544
566
|
class _TensorIndexGetitem(base.TensorIndexGetitem_):
|
|
@@ -564,26 +586,6 @@ def tensor_index_by_slice(data, slice_index):
|
|
|
564
586
|
return _tensor_index_getitem(data, slice_index)
|
|
565
587
|
|
|
566
588
|
|
|
567
|
-
def get_stride_info_from_slice(data, slice_index):
|
|
568
|
-
"""get the stride info from slice index"""
|
|
569
|
-
data_shape = F.dyn_shape(data)
|
|
570
|
-
begin_strides, end_strides, step_strides = [], [], []
|
|
571
|
-
start, stop, step = get_slice_stride(slice_index, data_shape[0])
|
|
572
|
-
if start.ndim > 0:
|
|
573
|
-
start = start.item()
|
|
574
|
-
if stop.ndim > 0:
|
|
575
|
-
stop = stop.item()
|
|
576
|
-
if step.ndim > 0:
|
|
577
|
-
step = step.item()
|
|
578
|
-
begin_strides.append(start)
|
|
579
|
-
end_strides.append(stop)
|
|
580
|
-
step_strides.append(step)
|
|
581
|
-
begin_tensor = stack(begin_strides)
|
|
582
|
-
end_tensor = stack(end_strides)
|
|
583
|
-
step_tensor = stack(step_strides)
|
|
584
|
-
return begin_tensor, end_tensor, step_tensor
|
|
585
|
-
|
|
586
|
-
|
|
587
589
|
def tensor_index_by_number(data, number_index):
|
|
588
590
|
"""Tensor getitem by a Number which may be integer/float/bool value"""
|
|
589
591
|
if isinstance(number_index, bool):
|
|
@@ -607,31 +609,18 @@ def _tensor_index_by_bool(data, bool_value):
|
|
|
607
609
|
return output
|
|
608
610
|
|
|
609
611
|
|
|
610
|
-
def get_stride_info_from_integer(
|
|
612
|
+
def get_stride_info_from_integer(int_index):
|
|
611
613
|
"""Convert integer to slice"""
|
|
612
|
-
begin_strides =
|
|
613
|
-
end_strides =
|
|
614
|
-
step_strides =
|
|
615
|
-
|
|
616
|
-
end_tensor = stack(end_strides)
|
|
617
|
-
step_tensor = stack(step_strides)
|
|
618
|
-
return begin_tensor, end_tensor, step_tensor
|
|
614
|
+
begin_strides = (int_index,)
|
|
615
|
+
end_strides = (int_index + 1,)
|
|
616
|
+
step_strides = (1,)
|
|
617
|
+
return begin_strides, end_strides, step_strides
|
|
619
618
|
|
|
620
619
|
|
|
621
620
|
def _tensor_index_by_integer(data, int_index):
|
|
622
621
|
"""Tensor getitem by a single integer number"""
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
tensor_index = _scalar_to_tensor(int_index)
|
|
626
|
-
begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
|
|
627
|
-
else:
|
|
628
|
-
if not data_shape:
|
|
629
|
-
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
630
|
-
if data.ndim < 1 or data.ndim > 8:
|
|
631
|
-
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
632
|
-
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
|
633
|
-
begin_strides, end_strides, step_strides = \
|
|
634
|
-
const_utils.get_stride_info_from_integer(data_shape, transformed_number)
|
|
622
|
+
begin_strides, end_strides, step_strides = get_stride_info_from_integer(int_index)
|
|
623
|
+
|
|
635
624
|
shrink_axis_mask = 1
|
|
636
625
|
begin_mask = 0
|
|
637
626
|
end_mask = 0
|
|
@@ -664,6 +653,7 @@ def tensor_index_by_tensor(data, tensor_index):
|
|
|
664
653
|
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
665
654
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
666
655
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
|
|
656
|
+
tensor_index = F.select(tensor_index < 0, tensor_index + F.shape(data)[0], tensor_index)
|
|
667
657
|
return F.gather(data, tensor_index, 0)
|
|
668
658
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
|
|
669
659
|
return tensor_index_by_bool_tensor(data, tensor_index)
|
|
@@ -676,27 +666,23 @@ def tensor_index_by_tensor(data, tensor_index):
|
|
|
676
666
|
def tensor_index_by_list(data, list_index):
|
|
677
667
|
"""Tensor getitem by list of int and bool"""
|
|
678
668
|
min_data_dim, max_data_dim = 1, 8
|
|
679
|
-
|
|
669
|
+
if F.isconstant(data.ndim):
|
|
670
|
+
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
680
671
|
|
|
681
672
|
data_shape = F.shape(data)
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
list_index, data_shape[0])
|
|
696
|
-
if tensor_index is False:
|
|
697
|
-
const_utils.raise_index_error(
|
|
698
|
-
"When tensor is indexed by list, the list can't be empty.")
|
|
699
|
-
return F.gather(data, tensor_index, 0)
|
|
673
|
+
if F.isconstant(data_shape[0]) and all(isinstance(i, bool) for i in list_index):
|
|
674
|
+
if data_shape[0] != len(list_index):
|
|
675
|
+
raise IndexError(
|
|
676
|
+
f'dimension is {data_shape[0]} but corresponding boolean dimension is {len(list_index)}')
|
|
677
|
+
tensor_index = Tensor(list_index).nonzero()
|
|
678
|
+
return F.gather_nd(data, tensor_index)
|
|
679
|
+
|
|
680
|
+
if not list_index:
|
|
681
|
+
const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
|
|
682
|
+
|
|
683
|
+
index_tensor = _convert_list_index_to_tensor(list_index)
|
|
684
|
+
if index_tensor is not None:
|
|
685
|
+
return tensor_index_by_tensor(data, index_tensor)
|
|
700
686
|
|
|
701
687
|
tuple_index_new = ()
|
|
702
688
|
for index in list_index:
|
|
@@ -704,16 +690,6 @@ def tensor_index_by_list(data, list_index):
|
|
|
704
690
|
return tensor_index_by_tuple(data, tuple_index_new)
|
|
705
691
|
|
|
706
692
|
|
|
707
|
-
def convert_tupleslice_to_tensor(tuple_index):
|
|
708
|
-
"""convert mutable scalar in slice to tensor"""
|
|
709
|
-
new_tuple_index = []
|
|
710
|
-
for item in tuple_index:
|
|
711
|
-
if isinstance(item, slice):
|
|
712
|
-
item = convert_variable_to_tensor_slice(item)
|
|
713
|
-
new_tuple_index.append(item)
|
|
714
|
-
return tuple(new_tuple_index)
|
|
715
|
-
|
|
716
|
-
|
|
717
693
|
def judge_tuple_index_dim_check_error(index_dim, data_dim):
|
|
718
694
|
"""raise IndexError when tuple_index's dim is invalid"""
|
|
719
695
|
if index_dim > data_dim:
|
|
@@ -721,29 +697,6 @@ def judge_tuple_index_dim_check_error(index_dim, data_dim):
|
|
|
721
697
|
f"dim of index:{index_dim}, dim of data:{data_dim}")
|
|
722
698
|
|
|
723
699
|
|
|
724
|
-
class _HandleEmptySlice(base.HandleEmptySlice_):
|
|
725
|
-
"""
|
|
726
|
-
Getting item of Tensor.
|
|
727
|
-
|
|
728
|
-
Args:
|
|
729
|
-
data (Tensor): A tuple to be sliced.
|
|
730
|
-
index: Index of tensor.
|
|
731
|
-
|
|
732
|
-
Returns:
|
|
733
|
-
Type is the same as the element type of data.
|
|
734
|
-
"""
|
|
735
|
-
|
|
736
|
-
def __init__(self, name):
|
|
737
|
-
"""Initialize _HandleEmptySlice."""
|
|
738
|
-
base.HandleEmptySlice_.__init__(self, name)
|
|
739
|
-
|
|
740
|
-
def __call__(self, *args):
|
|
741
|
-
pass
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
_handle_empty_slice = _HandleEmptySlice('handle_zero_tuple_index')
|
|
745
|
-
|
|
746
|
-
|
|
747
700
|
def judge_tuple_index_dim(data, tuple_index):
|
|
748
701
|
"""Judge whether tuple_index's dim is valid"""
|
|
749
702
|
data_dim = data.ndim
|
|
@@ -756,50 +709,20 @@ def judge_tuple_index_dim(data, tuple_index):
|
|
|
756
709
|
judge_tuple_index_dim_check_error(index_dim, data_dim)
|
|
757
710
|
|
|
758
711
|
|
|
759
|
-
def judge_simple_tuple_index(data, tuple_index):
|
|
760
|
-
"""Judge whether tuple_index is simple index, which not rollback to cpu ops."""
|
|
761
|
-
op_name = const_utils.TENSOR_GETITEM
|
|
762
|
-
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
763
|
-
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
764
|
-
return F.isconstant(tuple_index) and contain_type == const_utils.ALL_BASIC \
|
|
765
|
-
and F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(F.rank(data))
|
|
766
|
-
|
|
767
|
-
|
|
768
712
|
def tensor_index_by_tuple(data, tuple_index):
|
|
769
713
|
"""Tensor getitem by tuple of various types with None"""
|
|
770
714
|
if not tuple_index:
|
|
771
715
|
return data
|
|
772
|
-
if judge_simple_tuple_index(data, tuple_index):
|
|
773
|
-
tuple_index = convert_tupleslice_to_tensor(tuple_index)
|
|
774
|
-
op_name = const_utils.TENSOR_GETITEM
|
|
775
|
-
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
776
|
-
min_data_dim, max_data_dim = 1, 8
|
|
777
|
-
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
778
|
-
return _tensor_getitem_by_tuple_slice(data, tuple_index)
|
|
779
716
|
|
|
780
717
|
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
781
718
|
judge_tuple_index_dim(data, tuple_index)
|
|
782
719
|
tuple_index, zero_index, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
783
720
|
for non_zero_shape in non_zero_shapes:
|
|
784
|
-
if
|
|
721
|
+
if 0 in non_zero_shape:
|
|
785
722
|
tuple_index = zero_index
|
|
786
723
|
break
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
if 0 in stub_zero_dim_tensor.shape:
|
|
790
|
-
return F.fill(data.dtype, stub_zero_dim_tensor.shape, 0)
|
|
791
|
-
has_tensor_index = False
|
|
792
|
-
for i in tuple_index:
|
|
793
|
-
if isinstance(i, Tensor):
|
|
794
|
-
has_tensor_index = True
|
|
795
|
-
break
|
|
796
|
-
empty_broadcast_data_shape = False
|
|
797
|
-
_broadcast_data_shape = _handle_scalar_tensor_index(data, tuple_index)
|
|
798
|
-
if has_tensor_index and isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
|
|
799
|
-
empty_broadcast_data_shape = True
|
|
800
|
-
if has_tensor_index and isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
|
|
801
|
-
empty_broadcast_data_shape = True
|
|
802
|
-
return _tensor_index_getitem(data, tuple_index, empty_broadcast_data_shape)
|
|
724
|
+
|
|
725
|
+
return _tensor_index_getitem(data, tuple_index)
|
|
803
726
|
|
|
804
727
|
|
|
805
728
|
def get_slice_stride(slice_index, dim_size):
|
|
@@ -809,20 +732,20 @@ def get_slice_stride(slice_index, dim_size):
|
|
|
809
732
|
step = slice_get_item(slice_index, "step")
|
|
810
733
|
|
|
811
734
|
if start is None:
|
|
812
|
-
start =
|
|
735
|
+
start = 0
|
|
813
736
|
if stop is None:
|
|
814
737
|
stop = dim_size
|
|
815
738
|
if step is None:
|
|
816
|
-
step =
|
|
739
|
+
step = 1
|
|
817
740
|
|
|
818
|
-
if
|
|
819
|
-
start =
|
|
741
|
+
if isinstance(start, Tensor):
|
|
742
|
+
start = int(start)
|
|
820
743
|
|
|
821
|
-
if
|
|
822
|
-
stop =
|
|
744
|
+
if isinstance(stop, Tensor):
|
|
745
|
+
stop = int(stop)
|
|
823
746
|
|
|
824
|
-
if
|
|
825
|
-
step =
|
|
747
|
+
if isinstance(step, Tensor):
|
|
748
|
+
step = int(step)
|
|
826
749
|
|
|
827
750
|
return start, stop, step
|
|
828
751
|
|
|
@@ -841,190 +764,6 @@ def cal_tuple_slice_mask(data_shape, tuple_index):
|
|
|
841
764
|
return begin_mask, end_mask
|
|
842
765
|
|
|
843
766
|
|
|
844
|
-
def _get_stride_info_from_tuple(data, tuple_index):
|
|
845
|
-
"""get the stride info from tuple"""
|
|
846
|
-
data_shape = F.dyn_shape(data)
|
|
847
|
-
begin_strides, end_strides, step_strides = [], [], []
|
|
848
|
-
tuple_index_len = len(tuple_index)
|
|
849
|
-
data_dim = data.ndim
|
|
850
|
-
shrink_axis, index_count, ellipsis_count = 0, 0, 0
|
|
851
|
-
for item in range(data_dim):
|
|
852
|
-
if item >= tuple_index_len or item >= data_dim:
|
|
853
|
-
break
|
|
854
|
-
index = tuple_index[item]
|
|
855
|
-
dim_size = data_shape[item]
|
|
856
|
-
if isinstance(index, slice):
|
|
857
|
-
start, stop, step = get_slice_stride(index, dim_size)
|
|
858
|
-
begin_strides.append(start)
|
|
859
|
-
end_strides.append(stop)
|
|
860
|
-
step_strides.append(step)
|
|
861
|
-
index_count = index_count + 1
|
|
862
|
-
elif isinstance(index, int):
|
|
863
|
-
int_tensor = _scalar_to_tensor(index)
|
|
864
|
-
begin_strides.append(int_tensor)
|
|
865
|
-
end_strides.append(int_tensor + const_utils.make_tensor(1))
|
|
866
|
-
step_strides.append(const_utils.make_tensor(1))
|
|
867
|
-
shrink_axis = shrink_axis + (2 ** index_count)
|
|
868
|
-
index_count = index_count + 1
|
|
869
|
-
elif index is ...:
|
|
870
|
-
ellipsis_count = ellipsis_count + 1
|
|
871
|
-
if ellipsis_count > 1:
|
|
872
|
-
const_utils.raise_value_error("An index can have only one ellipsis (...)")
|
|
873
|
-
ellipsis_range_size = data_dim - tuple_index_len + 1
|
|
874
|
-
begin_strides.extend([const_utils.make_tensor(0)] * ellipsis_range_size)
|
|
875
|
-
end_strides.extend(
|
|
876
|
-
[shape for shape in data_shape[index_count: index_count + ellipsis_range_size]])
|
|
877
|
-
step_strides.extend([const_utils.make_tensor(1)] * ellipsis_range_size)
|
|
878
|
-
index_count = index_count + ellipsis_range_size
|
|
879
|
-
else:
|
|
880
|
-
exp_msg = const_utils.gen_exception_msg("Not supported index data type, got {}, type is {}", index,
|
|
881
|
-
type(index))
|
|
882
|
-
const_utils.raise_index_error(exp_msg)
|
|
883
|
-
begin_tensor = stack(begin_strides)
|
|
884
|
-
end_tensor = stack(end_strides)
|
|
885
|
-
step_tensor = stack(step_strides)
|
|
886
|
-
strides_v = {
|
|
887
|
-
'begin': begin_tensor,
|
|
888
|
-
'end': end_tensor,
|
|
889
|
-
'step': step_tensor
|
|
890
|
-
}
|
|
891
|
-
return strides_v, shrink_axis
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|
895
|
-
"""Tensor getitem by a tuple of slice"""
|
|
896
|
-
data_shape = F.shape(data)
|
|
897
|
-
is_dynamic = F.is_sequence_value_unknown(data_shape)
|
|
898
|
-
for item in tuple_index:
|
|
899
|
-
if isinstance(item, slice):
|
|
900
|
-
is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
|
|
901
|
-
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
|
902
|
-
or isinstance(slice_get_item(item, "step"), Tensor)
|
|
903
|
-
|
|
904
|
-
strides_v = {}
|
|
905
|
-
shrink_axis_mask = 0
|
|
906
|
-
if not is_dynamic:
|
|
907
|
-
strides_v, shrink_axis_mask = const_utils.get_stride_info_from_tuple(
|
|
908
|
-
data_shape, tuple_index)
|
|
909
|
-
else:
|
|
910
|
-
strides_v, shrink_axis_mask = _get_stride_info_from_tuple(
|
|
911
|
-
data, tuple_index)
|
|
912
|
-
begin_mask, end_mask = cal_tuple_slice_mask(data_shape, tuple_index)
|
|
913
|
-
begin_v = strides_v['begin']
|
|
914
|
-
end_v = strides_v['end']
|
|
915
|
-
step_v = strides_v['step']
|
|
916
|
-
return strided_slice(data, begin_v, end_v, step_v, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
@_primexpr
|
|
920
|
-
def _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
921
|
-
tensor_positions_new):
|
|
922
|
-
""" parse index of bool tensor type """
|
|
923
|
-
indices = index.nonzero()
|
|
924
|
-
if indices.shape[0] == 0:
|
|
925
|
-
return None, tensor_indexes, tensor_positions_new
|
|
926
|
-
indices = F.cast(indices, mstype.int64)
|
|
927
|
-
indices = indices.T
|
|
928
|
-
for sub_index in indices:
|
|
929
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
930
|
-
tuple_index_new += (sub_index,)
|
|
931
|
-
tensor_indexes.append(sub_index)
|
|
932
|
-
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
def _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new, tensor_indexes, tensor_positions_new):
|
|
936
|
-
""" parse index of tensor type """
|
|
937
|
-
if F.dtype(index) in mstype.int_type:
|
|
938
|
-
tensor_index = F.cast(index, mstype.int64)
|
|
939
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
940
|
-
tuple_index_new += (tensor_index,)
|
|
941
|
-
tensor_indexes.append(tensor_index)
|
|
942
|
-
elif F.dtype(index) == mstype.bool_:
|
|
943
|
-
return _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
944
|
-
tensor_positions_new)
|
|
945
|
-
else:
|
|
946
|
-
exp_msg = const_utils.gen_exception_msg(
|
|
947
|
-
"The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
|
|
948
|
-
const_utils.raise_index_error(exp_msg)
|
|
949
|
-
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|
953
|
-
"""Tensor getitem by a tuple of mixed tensor."""
|
|
954
|
-
slice_is_tensor = False
|
|
955
|
-
for item in tuple_index:
|
|
956
|
-
if isinstance(item, slice):
|
|
957
|
-
slice_is_tensor = isinstance(slice_get_item(item, "start"), Tensor) \
|
|
958
|
-
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
|
959
|
-
or isinstance(slice_get_item(item, "step"), Tensor)
|
|
960
|
-
if slice_is_tensor:
|
|
961
|
-
const_utils.raise_index_error("Not supported when slice has tensor")
|
|
962
|
-
|
|
963
|
-
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
964
|
-
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
|
|
965
|
-
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
|
966
|
-
data_shape = F.shape(data)
|
|
967
|
-
tensor_indexes, slice_indexes = [], []
|
|
968
|
-
tuple_index_new, slice_shapes = (), ()
|
|
969
|
-
slice_positions_new, tensor_positions_new = [], []
|
|
970
|
-
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
|
|
971
|
-
if i in int_positions:
|
|
972
|
-
int_index = const_utils.check_range(index, dim_size)
|
|
973
|
-
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
|
|
974
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
975
|
-
tensor_index = _scalar_to_tensor(int_index)
|
|
976
|
-
tensor_index = F.cast(tensor_index, mstype.int64)
|
|
977
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
978
|
-
tuple_index_new += (tensor_index,)
|
|
979
|
-
tensor_indexes.append(tensor_index)
|
|
980
|
-
elif i in sequence_positions:
|
|
981
|
-
tensor_index = const_utils.sequence_to_index(index, dim_size)
|
|
982
|
-
if tensor_index is False:
|
|
983
|
-
const_utils.raise_index_error("The sequence element(tuple/list) in tuple index can't be empty.")
|
|
984
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
985
|
-
tuple_index_new += (tensor_index,)
|
|
986
|
-
tensor_indexes.append(tensor_index)
|
|
987
|
-
elif i in tensor_positions:
|
|
988
|
-
tuple_index_new, tensor_indexes, tensor_positions_new = \
|
|
989
|
-
_tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new,
|
|
990
|
-
tensor_indexes, tensor_positions_new)
|
|
991
|
-
if tuple_index_new is None:
|
|
992
|
-
return Tensor([])
|
|
993
|
-
elif i in slice_positions:
|
|
994
|
-
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
|
|
995
|
-
slice_shapes += (len(slice_ele_list_index),)
|
|
996
|
-
slice_positions_new.append(len(tuple_index_new))
|
|
997
|
-
tuple_index_new += (slice_ele_list_index,)
|
|
998
|
-
slice_indexes.append(slice_ele_list_index)
|
|
999
|
-
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
|
1000
|
-
broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
|
|
1001
|
-
const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions_new, tensor_indexes_shapes,
|
|
1002
|
-
slice_shapes, op_name)
|
|
1003
|
-
|
|
1004
|
-
tuple_index_len = len(tuple_index)
|
|
1005
|
-
if 0 in final_shape + data_shape:
|
|
1006
|
-
if tuple_index_len < len(data_shape):
|
|
1007
|
-
final_shape = final_shape + data_shape[tuple_index_len:]
|
|
1008
|
-
return const_utils.make_tensor([], data.dtype, final_shape)
|
|
1009
|
-
|
|
1010
|
-
final_index_tensors = []
|
|
1011
|
-
slice_cnt = 0
|
|
1012
|
-
for i, index in enumerate(tuple_index_new):
|
|
1013
|
-
if i in tensor_positions_new:
|
|
1014
|
-
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
|
|
1015
|
-
index)
|
|
1016
|
-
final_index_tensors.append(transform_tensor)
|
|
1017
|
-
elif i in slice_positions_new:
|
|
1018
|
-
slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
|
|
1019
|
-
slice_shapes, fancy_position)
|
|
1020
|
-
final_index_tensors.append(slice_index_tensor)
|
|
1021
|
-
slice_cnt += 1
|
|
1022
|
-
|
|
1023
|
-
indices = stack(final_index_tensors)
|
|
1024
|
-
result = F.gather_nd(data, indices)
|
|
1025
|
-
return result
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
767
|
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
|
|
1029
768
|
"""Generate an indices tensor from a tuple of tensor."""
|
|
1030
769
|
indexes_types = hyper_map(F.dtype, tuple_index)
|
|
@@ -1116,8 +855,15 @@ def sequence_to_tensor(value, dtype):
|
|
|
1116
855
|
|
|
1117
856
|
if value_elements_type == const_utils.ALL_TENSOR:
|
|
1118
857
|
value = F.stack(value).astype(dtype)
|
|
1119
|
-
elif value_elements_type == const_utils.NO_TENSOR
|
|
1120
|
-
|
|
858
|
+
elif value_elements_type == const_utils.NO_TENSOR:
|
|
859
|
+
if isinstance(value, list):
|
|
860
|
+
value = tuple(value)
|
|
861
|
+
|
|
862
|
+
if dtype == mstype.float16:
|
|
863
|
+
value = TupleToTensor()(value, mstype.float32)
|
|
864
|
+
value = F.cast(value, dtype)
|
|
865
|
+
else:
|
|
866
|
+
value = TupleToTensor()(value, dtype)
|
|
1121
867
|
else:
|
|
1122
868
|
new_value = ()
|
|
1123
869
|
for ele in value:
|
|
@@ -1138,57 +884,31 @@ def _generate_updates_from_sequence(data, index, value, op_type):
|
|
|
1138
884
|
def _generate_updates_from_tensor(data, index, value, op_type):
|
|
1139
885
|
"""Generate an updates tensor from a tensor."""
|
|
1140
886
|
value = value.astype(data.dtype)
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
|
1145
|
-
updates = ops.broadcast_to(value, updates_shape)
|
|
1146
|
-
return updates
|
|
1147
|
-
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type, False)
|
|
1148
|
-
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
|
|
1149
|
-
if need_broadcast:
|
|
1150
|
-
return _broadcast(updates_shape, value)
|
|
1151
|
-
return value
|
|
887
|
+
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type)
|
|
888
|
+
updates = ops.broadcast_to(value, updates_shape)
|
|
889
|
+
return updates
|
|
1152
890
|
|
|
1153
891
|
|
|
1154
892
|
# Tensor getitem implementations are above this line, setitem implementations below.
|
|
1155
893
|
|
|
1156
|
-
def
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
return tensor_setitem_by_tensor_with_tensor(self, index, value)
|
|
1161
|
-
return tensor_setitem_by_tensor_with_sequence(self, index, value)
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
def tensor_setitem_by_tuple(self, index, value):
|
|
1165
|
-
index = convert_tupleslice_to_tensor(index)
|
|
1166
|
-
if isinstance(value, (int, float, bool)):
|
|
1167
|
-
index = format_tuple_indices(index)
|
|
1168
|
-
return tensor_setitem_by_tuple_with_number(self, index, value)
|
|
1169
|
-
if isinstance(value, Tensor):
|
|
1170
|
-
return tensor_setitem_by_tuple_with_tensor(self, index, value)
|
|
1171
|
-
return tensor_setitem_by_tuple_with_sequence(self, index, value)
|
|
1172
|
-
|
|
894
|
+
def _tensor_index_transfer(index, broadcast_shape, final_shape, new_shape):
|
|
895
|
+
"""Transform tuple index tensor to the required."""
|
|
896
|
+
if 0 in final_shape:
|
|
897
|
+
return F.fill(index.dtype, final_shape, 0)
|
|
1173
898
|
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
899
|
+
if broadcast_shape == ():
|
|
900
|
+
# broadcast_to () is not support on Ascend
|
|
901
|
+
item = index
|
|
902
|
+
else:
|
|
903
|
+
item = F.broadcast_to(index, broadcast_shape)
|
|
904
|
+
item = F.reshape(item, new_shape)
|
|
905
|
+
return F.broadcast_to(item, final_shape)
|
|
1180
906
|
|
|
1181
907
|
|
|
1182
|
-
def
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
x = F.broadcast_to(x, broadcast_shape)
|
|
1187
|
-
x = F.reshape(x, new_shape)
|
|
1188
|
-
x = F.broadcast_to(x, final_shape)
|
|
1189
|
-
return x
|
|
1190
|
-
item = _broadcast(broadcast_shape, x)
|
|
1191
|
-
return _broadcast(final_shape, F.reshape(item, new_shape))
|
|
908
|
+
def reshape_with_check(x, new_shape):
|
|
909
|
+
if isinstance(new_shape, Tensor):
|
|
910
|
+
new_shape = TensorToTuple()(new_shape)
|
|
911
|
+
return F.reshape(x, new_shape)
|
|
1192
912
|
|
|
1193
913
|
|
|
1194
914
|
class _TensorIndexSetitem(base.TensorIndexSetitem_):
|
|
@@ -1218,9 +938,10 @@ def tensor_setitem_by_slice(self, index, value):
|
|
|
1218
938
|
return self
|
|
1219
939
|
value = F.broadcast_to(value, value_shape)
|
|
1220
940
|
if not const_utils.is_ascend() and step == 1:
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
941
|
+
start = (start,)
|
|
942
|
+
stop = (stop,)
|
|
943
|
+
step = (step,)
|
|
944
|
+
return copy_slice(self, value, start, stop, step)
|
|
1224
945
|
return F.tensor_scatter_update(self, indices, value)
|
|
1225
946
|
|
|
1226
947
|
|
|
@@ -1236,14 +957,14 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
|
|
1236
957
|
"""Set a tensor item by an int tensor with a tensor."""
|
|
1237
958
|
if F.rank(index) == 0:
|
|
1238
959
|
index = F.expand_dims(index, -1)
|
|
1239
|
-
|
|
960
|
+
|
|
1240
961
|
data_shape = F.shape(data)
|
|
962
|
+
updates_shape = index.shape + data_shape[1:]
|
|
963
|
+
value = F.cast(value, F.dtype(data))
|
|
964
|
+
updates = ops.broadcast_to(value, updates_shape)
|
|
1241
965
|
first_val = data_shape[0]
|
|
1242
966
|
index = F.select(index < 0, index + first_val, index)
|
|
1243
967
|
index = F.expand_dims(index, -1)
|
|
1244
|
-
if F.rank(index) < 2:
|
|
1245
|
-
index = F.expand_dims(index, 0)
|
|
1246
|
-
updates = F.expand_dims(updates, 0)
|
|
1247
968
|
if is_parameter(data):
|
|
1248
969
|
F.scatter_nd_update(data, index, updates)
|
|
1249
970
|
return data
|
|
@@ -1255,8 +976,7 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
|
|
1255
976
|
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
|
1256
977
|
index = F.broadcast_to(index, data.shape)
|
|
1257
978
|
value = F.cast(value, F.dtype(data))
|
|
1258
|
-
|
|
1259
|
-
value = value.unsqueeze(-1)
|
|
979
|
+
value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
|
|
1260
980
|
value = F.broadcast_to(value, data.shape)
|
|
1261
981
|
result = F.select(index, value, data)
|
|
1262
982
|
return result
|
|
@@ -1269,8 +989,6 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|
|
1269
989
|
if tensor_dtype == const_utils.INT_:
|
|
1270
990
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
|
1271
991
|
|
|
1272
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1273
|
-
return tensor_setitem_by_tuple_with_tensor(data, (index,), value_tensor.astype(data.dtype))
|
|
1274
992
|
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
|
1275
993
|
|
|
1276
994
|
|
|
@@ -1281,33 +999,8 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
|
|
|
1281
999
|
|
|
1282
1000
|
def tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
1283
1001
|
"""Assigns the tensor by tensor with tuple value."""
|
|
1284
|
-
index_dtype = F.dtype(index)
|
|
1285
|
-
if index_dtype in (mstype.int32, mstype.int64):
|
|
1286
|
-
return _tensor_setitem_by_tensor_with_sequence(data, index, value)
|
|
1287
|
-
if index_dtype == mstype.bool_:
|
|
1288
|
-
return _tensor_setitem_by_bool_tensor_with_sequence(data, index, value)
|
|
1289
|
-
exp_msg = const_utils.gen_exception_msg("The tensor index must be int or bool type, but got {}.", index_dtype)
|
|
1290
|
-
const_utils.raise_index_error(exp_msg)
|
|
1291
|
-
return None
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
def _tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
1295
|
-
"""Set a tensor item by a tensor with a tuple."""
|
|
1296
|
-
updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
|
1297
|
-
index = F.expand_dims(index, -1)
|
|
1298
|
-
return F.tensor_scatter_update(data, index, updates)
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
|
|
1302
|
-
"""Set a tensor item by a bool tensor with a tuple."""
|
|
1303
1002
|
value = sequence_to_tensor(value, F.dtype(data))
|
|
1304
|
-
return
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
|
1308
|
-
"""Givens a scalar assign to tensor by slice"""
|
|
1309
|
-
value = F.cast(value, F.dtype(data))
|
|
1310
|
-
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
|
1003
|
+
return tensor_setitem_by_tensor_with_tensor(data, index, value)
|
|
1311
1004
|
|
|
1312
1005
|
|
|
1313
1006
|
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|
@@ -1316,78 +1009,14 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|
|
1316
1009
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
|
1317
1010
|
|
|
1318
1011
|
|
|
1319
|
-
def
|
|
1320
|
-
"""
|
|
1321
|
-
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
stop_tensor = stack((stop,))
|
|
1325
|
-
step_tensor = stack((step,))
|
|
1326
|
-
dim0_size = stop_tensor - start_tensor
|
|
1327
|
-
if dim0_size <= 0:
|
|
1328
|
-
return data
|
|
1329
|
-
if dim0_size >= data_shape[0]:
|
|
1330
|
-
dim0_size = data_shape[0:1]
|
|
1331
|
-
value_shape = P.Concat(-1)((dim0_size, data_shape[1:]))
|
|
1332
|
-
value = ops.broadcast_to(value, value_shape)
|
|
1333
|
-
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
|
1337
|
-
"""Assigns a tensor value to the tensor by slice."""
|
|
1338
|
-
result = None
|
|
1339
|
-
check_result = const_utils.check_tensor_setitem_index(input_slice)
|
|
1340
|
-
if check_result:
|
|
1341
|
-
data_shape = F.shape(data)
|
|
1342
|
-
step = const_utils.get_step_from_slice(input_slice)
|
|
1343
|
-
if step == 1 and not const_utils.is_ascend():
|
|
1344
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1345
|
-
return tensor_copy_slice_from_slice(data, input_slice, value)
|
|
1346
|
-
start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
|
|
1347
|
-
dim0_size = stop - start
|
|
1348
|
-
if dim0_size <= 0:
|
|
1349
|
-
return data
|
|
1350
|
-
value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
|
|
1351
|
-
value = _broadcast(value_shape, value)
|
|
1352
|
-
return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
|
|
1353
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1354
|
-
const_utils.raise_unimplemented_error(
|
|
1355
|
-
"Not supported to take the subscript of dynamic shape tensor slice setitem")
|
|
1356
|
-
indices = const_utils.slice2indices(input_slice, data_shape)
|
|
1357
|
-
if indices is False:
|
|
1358
|
-
return data
|
|
1359
|
-
value_shape = const_utils.tuple_slice(F.shape(indices), None, -1)
|
|
1360
|
-
value = _broadcast(value_shape, value)
|
|
1361
|
-
result = F.tensor_scatter_update(data, indices, value.astype(F.dtype(data)))
|
|
1362
|
-
return result
|
|
1012
|
+
def tensor_setitem_by_list(data, index, value):
|
|
1013
|
+
"""list indices will be converted to tuple or tensor based on its contents."""
|
|
1014
|
+
index_tensor = _convert_list_index_to_tensor(index)
|
|
1015
|
+
if index_tensor is not None:
|
|
1016
|
+
return tensor_setitem_by_tensor_with_tensor(data, index_tensor, value)
|
|
1363
1017
|
|
|
1018
|
+
return tensor_setitem_by_tuple_with_tensor(data, tuple(index), value)
|
|
1364
1019
|
|
|
1365
|
-
def tensor_setitem_by_slice_with_sequence(data, input_slice, value):
|
|
1366
|
-
"""Assigns a list/tuple value to the tensor by slice."""
|
|
1367
|
-
value = _generate_updates_from_sequence(data, input_slice, value, const_utils.SET_ITEM_BY_NON_TENSOR)
|
|
1368
|
-
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
1372
|
-
"""using TensorCopySlices by fixed model tuple."""
|
|
1373
|
-
data_shape = F.dyn_shape(data)
|
|
1374
|
-
dim1_start, dim1_stop, _ = get_slice_stride(tuple_index[1], data_shape[1])
|
|
1375
|
-
if dim1_stop - dim1_start <= 0:
|
|
1376
|
-
return data
|
|
1377
|
-
dim0_start = _scalar_to_tensor(tuple_index[0])
|
|
1378
|
-
dim0_stop = dim0_start + const_utils.make_tensor(1)
|
|
1379
|
-
start = (dim0_start, dim1_start)
|
|
1380
|
-
stop = (dim0_stop, dim1_stop)
|
|
1381
|
-
step = (const_utils.make_tensor(1), const_utils.make_tensor(1))
|
|
1382
|
-
start_tensor = stack(start)
|
|
1383
|
-
stop_tensor = stack(stop)
|
|
1384
|
-
step_tensor = stack(step)
|
|
1385
|
-
dim1_size = stack((dim1_stop - dim1_start,))
|
|
1386
|
-
if dim1_size > data_shape[1]:
|
|
1387
|
-
dim1_size = data_shape[1:2]
|
|
1388
|
-
value_shape = P.Concat(-1)((dim1_size, data_shape[2:]))
|
|
1389
|
-
value = ops.broadcast_to(value, value_shape)
|
|
1390
|
-
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
1391
1020
|
|
|
1392
1021
|
|
|
1393
1022
|
class _PreSetitemByTuple(base.PreSetitemByTuple_):
|
|
@@ -1436,50 +1065,28 @@ class _HandleBoolTensor(base.HandleBoolTensor_):
|
|
|
1436
1065
|
_handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
|
|
1437
1066
|
|
|
1438
1067
|
|
|
1439
|
-
class _HandleScalarTensorIndex(base.HandleScalarTensorIndex_):
|
|
1440
|
-
"""
|
|
1441
|
-
Getting item of Tensor.
|
|
1442
|
-
|
|
1443
|
-
Args:
|
|
1444
|
-
data (Tensor): A tuple to be sliced.
|
|
1445
|
-
index: Index of tensor.
|
|
1446
|
-
|
|
1447
|
-
Returns:
|
|
1448
|
-
Type is the same as the element type of data.
|
|
1449
|
-
"""
|
|
1450
|
-
|
|
1451
|
-
def __init__(self, name):
|
|
1452
|
-
"""Initialize _HandleBoolTensor."""
|
|
1453
|
-
base.HandleScalarTensorIndex_.__init__(self, name)
|
|
1454
|
-
|
|
1455
|
-
def __call__(self, *args):
|
|
1456
|
-
pass
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
_handle_scalar_tensor_index = _HandleScalarTensorIndex('handle_scalar_tensor_index')
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
1068
|
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
1463
1069
|
"""Assigns the tensor by tuple with tensor value."""
|
|
1464
1070
|
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1465
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1466
|
-
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1467
1071
|
dim1_start, dim1_stop, _ = const_utils.normalize_slice(
|
|
1468
1072
|
tuple_index[1], data.shape[1])
|
|
1073
|
+
if isinstance(dim1_start, Tensor):
|
|
1074
|
+
dim1_start = int(dim1_start)
|
|
1075
|
+
if isinstance(dim1_stop, Tensor):
|
|
1076
|
+
dim1_stop = int(dim1_stop)
|
|
1469
1077
|
if dim1_stop - dim1_start <= 0:
|
|
1470
1078
|
return data
|
|
1471
1079
|
dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
|
|
1472
1080
|
start = (dim0_start, dim1_start)
|
|
1473
1081
|
stop = (dim0_start + 1, dim1_stop)
|
|
1474
1082
|
step = (1, 1)
|
|
1475
|
-
value_shape = (dim1_stop - dim1_start,) +
|
|
1476
|
-
|
|
1477
|
-
value = _broadcast(value_shape, value)
|
|
1083
|
+
value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
|
|
1084
|
+
value = F.broadcast_to(value, value_shape)
|
|
1478
1085
|
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1479
1086
|
tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
1480
1087
|
|
|
1481
1088
|
for non_zero_shape in non_zero_shapes:
|
|
1482
|
-
if
|
|
1089
|
+
if 0 in non_zero_shape:
|
|
1483
1090
|
return data
|
|
1484
1091
|
value = value.astype(data.dtype)
|
|
1485
1092
|
special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
|
|
@@ -1512,17 +1119,19 @@ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1512
1119
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
1513
1120
|
|
|
1514
1121
|
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1515
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1516
|
-
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1517
1122
|
dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
|
|
1123
|
+
if isinstance(dim1_start, Tensor):
|
|
1124
|
+
dim1_start = int(dim1_start)
|
|
1125
|
+
if isinstance(dim1_stop, Tensor):
|
|
1126
|
+
dim1_stop = int(dim1_stop)
|
|
1518
1127
|
if dim1_stop - dim1_start <= 0:
|
|
1519
1128
|
return data
|
|
1520
1129
|
dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
|
|
1521
1130
|
start = (dim0_start, dim1_start)
|
|
1522
1131
|
stop = (dim0_start + 1, dim1_stop)
|
|
1523
1132
|
step = (1, 1)
|
|
1524
|
-
value_shape = (dim1_stop - dim1_start,) +
|
|
1525
|
-
value =
|
|
1133
|
+
value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
|
|
1134
|
+
value = F.broadcast_to(value, value_shape)
|
|
1526
1135
|
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1527
1136
|
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
|
|
1528
1137
|
|
|
@@ -1545,49 +1154,45 @@ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1545
1154
|
|
|
1546
1155
|
|
|
1547
1156
|
def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
|
|
1548
|
-
value =
|
|
1157
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1549
1158
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
|
1550
1159
|
|
|
1551
1160
|
|
|
1552
1161
|
def tensor_setitem_by_number_with_number(data, index, value):
|
|
1553
1162
|
"""Assigns the tensor by number with number value."""
|
|
1554
|
-
|
|
1555
|
-
|
|
1163
|
+
data_shape = F.shape(data)
|
|
1164
|
+
dim_size = data_shape[0]
|
|
1165
|
+
if index < 0:
|
|
1166
|
+
index += dim_size
|
|
1167
|
+
if index < -dim_size or index >= dim_size:
|
|
1168
|
+
raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
|
|
1169
|
+
index = F.cast(index, mstype.int64)
|
|
1170
|
+
index = F.reshape(index, (1, 1))
|
|
1171
|
+
|
|
1172
|
+
updates = F.cast(value, data.dtype)
|
|
1173
|
+
updates_shape = (1,) + data_shape[1:]
|
|
1174
|
+
updates = ops.broadcast_to(updates, updates_shape)
|
|
1175
|
+
|
|
1176
|
+
if is_parameter(data):
|
|
1177
|
+
F.scatter_nd_update(data, index, updates)
|
|
1178
|
+
return data
|
|
1179
|
+
return F.tensor_scatter_update(data, index, updates)
|
|
1556
1180
|
|
|
1557
1181
|
|
|
1558
1182
|
def tensor_setitem_by_number_with_sequence(data, index, value):
|
|
1559
1183
|
"""Assigns a list/tuple value to the tensor by slice."""
|
|
1560
|
-
value =
|
|
1184
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1561
1185
|
return tensor_setitem_by_number_with_tensor(data, index, value)
|
|
1562
1186
|
|
|
1563
1187
|
|
|
1564
1188
|
def tensor_setitem_by_number_with_tensor(data, index, value):
|
|
1565
|
-
|
|
1566
|
-
data_shape = F.shape(data)
|
|
1567
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1568
|
-
index = _scalar_to_tensor(index)
|
|
1569
|
-
index = F.expand_dims(index, -1)
|
|
1570
|
-
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value)
|
|
1571
|
-
|
|
1572
|
-
dim_size = data_shape[0]
|
|
1573
|
-
if index < -dim_size or index >= dim_size:
|
|
1574
|
-
raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
|
|
1575
|
-
index = const_utils.int_to_index(index, data_shape)
|
|
1576
|
-
value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
|
|
1577
|
-
value = _broadcast(value_shape, value.astype(F.dtype(data)))
|
|
1578
|
-
if is_parameter(data):
|
|
1579
|
-
F.scatter_nd_update(data, index, value)
|
|
1580
|
-
return data
|
|
1581
|
-
return F.tensor_scatter_update(data, index, value)
|
|
1189
|
+
return tensor_setitem_by_number_with_number(data, index, value)
|
|
1582
1190
|
|
|
1583
1191
|
|
|
1584
1192
|
def tensor_setitem_by_ellipsis_with_number(data, value):
|
|
1585
1193
|
"""Assigns the tensor by ellipsis with number value."""
|
|
1586
1194
|
data_shape = F.shape(data)
|
|
1587
1195
|
data_dtype = F.dtype(data)
|
|
1588
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1589
|
-
value = F.cast(value, F.dtype(data))
|
|
1590
|
-
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1591
1196
|
return F.fill(data_dtype, data_shape, value)
|
|
1592
1197
|
|
|
1593
1198
|
|
|
@@ -1597,21 +1202,16 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
|
|
|
1597
1202
|
data_dtype = F.dtype(data)
|
|
1598
1203
|
value = value.astype(data_dtype)
|
|
1599
1204
|
|
|
1600
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1601
|
-
data_shape = F.dyn_shape(data)
|
|
1602
|
-
data = ops.broadcast_to(value, data_shape)
|
|
1603
|
-
return data
|
|
1604
1205
|
value_shape = F.shape(value)
|
|
1605
1206
|
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
|
1606
1207
|
value = F.reshape(value, source_shape)
|
|
1607
|
-
|
|
1608
|
-
data = F.cast(value, data_dtype)
|
|
1208
|
+
data = F.broadcast_to(value, data_shape)
|
|
1609
1209
|
return data
|
|
1610
1210
|
|
|
1611
1211
|
|
|
1612
1212
|
def tensor_setitem_by_ellipsis_with_sequence(data, value):
|
|
1613
1213
|
"""Assigns a list/tuple value to the tensor by ellipsis."""
|
|
1614
|
-
value =
|
|
1214
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1615
1215
|
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1616
1216
|
|
|
1617
1217
|
|
|
@@ -1622,23 +1222,15 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1622
1222
|
if not index:
|
|
1623
1223
|
data_shape = (0,) + data_shape
|
|
1624
1224
|
if isinstance(value, (list, tuple)):
|
|
1625
|
-
value =
|
|
1626
|
-
|
|
1627
|
-
value =
|
|
1628
|
-
|
|
1629
|
-
value = const_utils.make_tensor(value, mstype.float32)
|
|
1630
|
-
|
|
1631
|
-
if F.is_sequence_value_unknown(data_shape) and index:
|
|
1632
|
-
data_shape = F.dyn_shape(data)
|
|
1633
|
-
value = value.astype(data_dtype)
|
|
1634
|
-
data = ops.broadcast_to(value, data_shape)
|
|
1635
|
-
return data
|
|
1636
|
-
value_shape = F.shape(value)
|
|
1637
|
-
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
|
1225
|
+
value = sequence_to_tensor(value, data_dtype)
|
|
1226
|
+
else:
|
|
1227
|
+
value = F.cast(value, data_dtype)
|
|
1228
|
+
|
|
1638
1229
|
if index:
|
|
1230
|
+
value_shape = F.shape(value)
|
|
1231
|
+
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
|
1639
1232
|
value = F.reshape(value, source_shape)
|
|
1640
|
-
|
|
1641
|
-
data = F.cast(value, data_dtype)
|
|
1233
|
+
data = F.broadcast_to(value, data_shape)
|
|
1642
1234
|
return data
|
|
1643
1235
|
|
|
1644
1236
|
|
|
@@ -1651,33 +1243,6 @@ def tensor_in_sequence(x, y):
|
|
|
1651
1243
|
return result
|
|
1652
1244
|
|
|
1653
1245
|
|
|
1654
|
-
def format_list_indices(list_indices, length):
|
|
1655
|
-
"""Convert list indices to tensor or tuple indices based on its contents."""
|
|
1656
|
-
indices_types = hyper_map(F.typeof, list_indices)
|
|
1657
|
-
# If eyery element in list is bool, it's treated as 1-D bool tensor.
|
|
1658
|
-
# If every element in list is int(not all bool), it's treated as int tensor.
|
|
1659
|
-
if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
|
|
1660
|
-
if not F.isconstant(length):
|
|
1661
|
-
return const_utils.sequence_to_index(list_indices, None)
|
|
1662
|
-
return const_utils.sequence_to_index(list_indices, length)
|
|
1663
|
-
# If list contains other types(.../list/tuple/None), it's treated as a tuple
|
|
1664
|
-
return const_utils.deep_tuple(list_indices)
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
def format_tuple_indices(tuple_indices):
|
|
1668
|
-
"""
|
|
1669
|
-
Format tuple indices by unpacking high-dimension tuple and removing expand
|
|
1670
|
-
dimension signs(Bool and None).
|
|
1671
|
-
"""
|
|
1672
|
-
res = ()
|
|
1673
|
-
for i in tuple_indices:
|
|
1674
|
-
if isinstance(i, (list, tuple)):
|
|
1675
|
-
res += (const_utils.unpack(i),)
|
|
1676
|
-
else:
|
|
1677
|
-
res += (i,)
|
|
1678
|
-
return res
|
|
1679
|
-
|
|
1680
|
-
|
|
1681
1246
|
@_primexpr
|
|
1682
1247
|
def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
|
|
1683
1248
|
""" Parse bool tensor index """
|