mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore has been flagged as possibly problematic; see the package's registry page for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +124 -25
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +182 -68
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +192 -252
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +67 -26
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +20 -7
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +10 -10
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +449 -129
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +8 -11
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +65 -30
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +73 -21
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +276 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +668 -514
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +43 -3
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +117 -104
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +455 -71
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +201 -116
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +83 -3
- mindspore/dataset/engine/datasets_text.py +39 -39
- mindspore/dataset/engine/datasets_user_defined.py +230 -141
- mindspore/dataset/engine/datasets_vision.py +78 -74
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +138 -66
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +41 -15
- mindspore/dataset/text/__init__.py +2 -5
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +7 -10
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +16 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +124 -15
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +18 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/status.h +14 -0
- mindspore/include/api/types.h +10 -10
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +3 -5
- mindspore/include/dataset/vision.h +58 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +3 -3
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +138 -103
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/cell.py +461 -323
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +292 -135
- mindspore/nn/layer/basic.py +288 -83
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +221 -45
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +150 -68
- mindspore/nn/layer/padding.py +64 -87
- mindspore/nn/layer/pooling.py +175 -12
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +55 -53
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +145 -88
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +4 -2
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +46 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +44 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +67 -68
- mindspore/numpy/array_ops.py +70 -66
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +147 -152
- mindspore/numpy/utils.py +3 -0
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +9 -6
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
- mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +162 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +147 -59
- mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +201 -66
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/function/__init__.py +53 -11
- mindspore/ops/function/array_func.py +1269 -1821
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +114 -5
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/grad_func.py +30 -22
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +1170 -2697
- mindspore/ops/function/nn_func.py +2116 -1128
- mindspore/ops/function/other_func.py +8 -8
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +435 -113
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +16 -15
- mindspore/ops/functional.py +355 -346
- mindspore/ops/op_info_register.py +18 -45
- mindspore/ops/operations/__init__.py +38 -24
- mindspore/ops/operations/_grad_ops.py +21 -927
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +173 -607
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +106 -2837
- mindspore/ops/operations/comm_ops.py +799 -127
- mindspore/ops/operations/custom_ops.py +124 -119
- mindspore/ops/operations/debug_ops.py +142 -41
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +666 -4972
- mindspore/ops/operations/nn_ops.py +205 -2213
- mindspore/ops/operations/other_ops.py +60 -49
- mindspore/ops/operations/random_ops.py +50 -54
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +216 -103
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +8 -4
- mindspore/parallel/_auto_parallel_context.py +100 -10
- mindspore/parallel/_cell_wrapper.py +99 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +67 -23
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +99 -22
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +173 -6
- mindspore/parallel/algo_parameter_config.py +8 -10
- mindspore/parallel/checkpoint_transform.py +204 -38
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +279 -37
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +4 -2
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +153 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +18 -20
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
- mindspore/profiler/parser/base_timeline_generator.py +25 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +4 -393
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/profiler_info.py +94 -7
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +631 -508
- mindspore/rewrite/__init__.py +2 -14
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +40 -115
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +597 -263
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +7 -5
- mindspore/train/_utils.py +204 -4
- mindspore/train/amp.py +335 -295
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +220 -43
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +5 -5
- mindspore/train/callback/_on_request_exit.py +136 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +12 -12
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -23
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +382 -76
- mindspore/train/serialization.py +787 -288
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -282
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
|
@@ -24,13 +24,13 @@ from mindspore.ops import operations as P
|
|
|
24
24
|
from mindspore.ops.composite import base
|
|
25
25
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
26
26
|
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
|
|
27
|
-
TopTypeof,
|
|
27
|
+
TopTypeof, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
|
|
28
28
|
SelectView, CopyWithSlice
|
|
29
|
+
from mindspore.ops.operations._sequence_ops import TensorToTuple, TensorToScalar, TupleToTensor
|
|
29
30
|
from mindspore.common import dtype as mstype
|
|
30
31
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
31
32
|
from mindspore.common.initializer import Zero
|
|
32
|
-
from mindspore.common import Tensor, CSRTensor, COOTensor
|
|
33
|
-
from mindspore.common import mutable
|
|
33
|
+
from mindspore.common import Tensor, CSRTensor, COOTensor, mutable
|
|
34
34
|
from mindspore import ops
|
|
35
35
|
from mindspore.ops.primitive import _primexpr
|
|
36
36
|
from mindspore import _checkparam as validator
|
|
@@ -137,7 +137,8 @@ def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=N
|
|
|
137
137
|
elif transfer_type == ValueTransferType.kGatherND:
|
|
138
138
|
if isinstance(new_index, list):
|
|
139
139
|
new_index = handle_multi_dim_index_tensor(new_index, arg)
|
|
140
|
-
|
|
140
|
+
new_index = format_index_tensor(new_index, (None, F.shape(data)[:F.shape(new_index)[-1]]))
|
|
141
|
+
data = F.gather_nd(data, new_index)
|
|
141
142
|
elif transfer_type == ValueTransferType.kTensorScatterUpdate:
|
|
142
143
|
if isinstance(new_index, list):
|
|
143
144
|
new_index = handle_multi_dim_index_tensor(new_index, arg)
|
|
@@ -217,8 +218,8 @@ def _tensor_setitem(self, index, value):
|
|
|
217
218
|
return output
|
|
218
219
|
|
|
219
220
|
|
|
220
|
-
tensor_operator_registry
|
|
221
|
-
tensor_operator_registry
|
|
221
|
+
setattr(tensor_operator_registry, "__getitem__", _tensor_getitem)
|
|
222
|
+
setattr(tensor_operator_registry, "__setitem__", _tensor_setitem)
|
|
222
223
|
|
|
223
224
|
|
|
224
225
|
def _tensor_add(self, other):
|
|
@@ -287,15 +288,15 @@ def _tensor_floordiv(self, other):
|
|
|
287
288
|
return F.floordiv(self, other)
|
|
288
289
|
|
|
289
290
|
|
|
290
|
-
tensor_operator_registry
|
|
291
|
-
tensor_operator_registry
|
|
292
|
-
tensor_operator_registry
|
|
293
|
-
tensor_operator_registry
|
|
294
|
-
tensor_operator_registry
|
|
295
|
-
tensor_operator_registry
|
|
296
|
-
tensor_operator_registry
|
|
297
|
-
tensor_operator_registry
|
|
298
|
-
tensor_operator_registry
|
|
291
|
+
setattr(tensor_operator_registry, '__add__', _tensor_add)
|
|
292
|
+
setattr(tensor_operator_registry, '__sub__', _tensor_sub)
|
|
293
|
+
setattr(tensor_operator_registry, '__mul__', _tensor_mul)
|
|
294
|
+
setattr(tensor_operator_registry, '__matmul__', _tensor_matmul)
|
|
295
|
+
setattr(tensor_operator_registry, '__truediv__', _tensor_div)
|
|
296
|
+
setattr(tensor_operator_registry, '__mod__', _tensor_mod)
|
|
297
|
+
setattr(tensor_operator_registry, '__pow__', _tensor_pow)
|
|
298
|
+
setattr(tensor_operator_registry, '__rpow__', _tensor_rpow)
|
|
299
|
+
setattr(tensor_operator_registry, '__floordiv__', _tensor_floordiv)
|
|
299
300
|
|
|
300
301
|
|
|
301
302
|
def _scalar_to_tensor(input_x):
|
|
@@ -317,24 +318,25 @@ def tensor_item(data, *args):
|
|
|
317
318
|
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
318
319
|
if data.ndim == 0:
|
|
319
320
|
_check_scalar_tensor_args(args)
|
|
320
|
-
return
|
|
321
|
+
return TensorToScalar()(data)
|
|
321
322
|
if len(args) == 1 and isinstance(args[0], tuple):
|
|
322
323
|
args = args[0]
|
|
323
324
|
|
|
324
325
|
args_types = hyper_map(F.typeof, args)
|
|
325
326
|
if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
|
|
326
327
|
if data.shape == (1,):
|
|
327
|
-
return
|
|
328
|
+
return TensorToScalar()(data[0])
|
|
328
329
|
const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
|
|
329
330
|
|
|
330
331
|
if not const_utils.judge_indexes_types(args_types, mstype.int64):
|
|
331
332
|
const_utils.raise_type_error("The index object cannot be interpreted as an integer")
|
|
332
333
|
|
|
333
334
|
if len(args) == data.ndim:
|
|
334
|
-
return
|
|
335
|
+
return tensor_index_by_tuple(data, args)
|
|
335
336
|
if len(args) > 1:
|
|
336
337
|
const_utils.raise_value_error("Incorrect number of indices for array")
|
|
337
|
-
|
|
338
|
+
output = _tensor_index_by_integer(F.reshape(data, (-1,)), args[0])
|
|
339
|
+
return TensorToScalar()(output)
|
|
338
340
|
|
|
339
341
|
|
|
340
342
|
def tensor_itemset(data, *args):
|
|
@@ -354,8 +356,8 @@ def tensor_itemset(data, *args):
|
|
|
354
356
|
return tensor_itemset_with_number(data, args[0])
|
|
355
357
|
|
|
356
358
|
|
|
357
|
-
tensor_operator_registry
|
|
358
|
-
tensor_operator_registry
|
|
359
|
+
setattr(tensor_operator_registry, "item", tensor_item)
|
|
360
|
+
setattr(tensor_operator_registry, "itemset", tensor_itemset)
|
|
359
361
|
|
|
360
362
|
|
|
361
363
|
def tensor_itemset_with_number(data, number_value):
|
|
@@ -481,6 +483,7 @@ def format_index_tensor(index, arg):
|
|
|
481
483
|
index[format_idx] = F.select(index_tensor < 0, index_tensor + format_dim, index_tensor)
|
|
482
484
|
return index
|
|
483
485
|
index = Tensor(index)
|
|
486
|
+
format_dims = Tensor(format_dims)
|
|
484
487
|
return F.select(index < 0, index + format_dims, index)
|
|
485
488
|
|
|
486
489
|
|
|
@@ -521,24 +524,45 @@ def _expand_data_dims(data, tuple_index):
|
|
|
521
524
|
return data, tuple_index_new
|
|
522
525
|
|
|
523
526
|
|
|
524
|
-
def
|
|
525
|
-
"""convert
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
527
|
+
def _convert_list_index_to_tensor(list_index):
|
|
528
|
+
"""convert list to tensor"""
|
|
529
|
+
has_bool = False
|
|
530
|
+
has_int = False
|
|
531
|
+
has_no_bool_int = False
|
|
532
|
+
for idx in list_index:
|
|
533
|
+
if isinstance(idx, bool):
|
|
534
|
+
has_bool = True
|
|
535
|
+
elif isinstance(idx, int):
|
|
536
|
+
has_int = True
|
|
537
|
+
else:
|
|
538
|
+
has_no_bool_int = True
|
|
539
|
+
|
|
540
|
+
all_bool = has_bool and not has_int and not has_no_bool_int
|
|
541
|
+
all_int = has_int and not has_bool and not has_no_bool_int
|
|
542
|
+
all_bool_or_int = not has_no_bool_int
|
|
543
|
+
|
|
544
|
+
if all_int:
|
|
545
|
+
index_tensor = TupleToTensor()(tuple(list_index), mstype.int64)
|
|
546
|
+
return index_tensor
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
if all_bool:
|
|
550
|
+
index_tensor = TupleToTensor()(tuple(list_index), mstype.bool_)
|
|
551
|
+
return index_tensor
|
|
552
|
+
|
|
553
|
+
# convert bool to int if index is mixture of (bool, int)
|
|
554
|
+
if all_bool_or_int:
|
|
555
|
+
new_index = []
|
|
556
|
+
for idx in list_index:
|
|
557
|
+
if isinstance(idx, bool):
|
|
558
|
+
new_idx = int(idx)
|
|
559
|
+
new_index.append(new_idx)
|
|
560
|
+
else:
|
|
561
|
+
new_index.append(idx)
|
|
562
|
+
index_tensor = TupleToTensor()(tuple(new_index), mstype.int64)
|
|
563
|
+
return index_tensor
|
|
564
|
+
|
|
565
|
+
return None
|
|
542
566
|
|
|
543
567
|
|
|
544
568
|
class _TensorIndexGetitem(base.TensorIndexGetitem_):
|
|
@@ -564,26 +588,6 @@ def tensor_index_by_slice(data, slice_index):
|
|
|
564
588
|
return _tensor_index_getitem(data, slice_index)
|
|
565
589
|
|
|
566
590
|
|
|
567
|
-
def get_stride_info_from_slice(data, slice_index):
|
|
568
|
-
"""get the stride info from slice index"""
|
|
569
|
-
data_shape = F.dyn_shape(data)
|
|
570
|
-
begin_strides, end_strides, step_strides = [], [], []
|
|
571
|
-
start, stop, step = get_slice_stride(slice_index, data_shape[0])
|
|
572
|
-
if start.ndim > 0:
|
|
573
|
-
start = start.item()
|
|
574
|
-
if stop.ndim > 0:
|
|
575
|
-
stop = stop.item()
|
|
576
|
-
if step.ndim > 0:
|
|
577
|
-
step = step.item()
|
|
578
|
-
begin_strides.append(start)
|
|
579
|
-
end_strides.append(stop)
|
|
580
|
-
step_strides.append(step)
|
|
581
|
-
begin_tensor = stack(begin_strides)
|
|
582
|
-
end_tensor = stack(end_strides)
|
|
583
|
-
step_tensor = stack(step_strides)
|
|
584
|
-
return begin_tensor, end_tensor, step_tensor
|
|
585
|
-
|
|
586
|
-
|
|
587
591
|
def tensor_index_by_number(data, number_index):
|
|
588
592
|
"""Tensor getitem by a Number which may be integer/float/bool value"""
|
|
589
593
|
if isinstance(number_index, bool):
|
|
@@ -607,31 +611,18 @@ def _tensor_index_by_bool(data, bool_value):
|
|
|
607
611
|
return output
|
|
608
612
|
|
|
609
613
|
|
|
610
|
-
def get_stride_info_from_integer(
|
|
614
|
+
def get_stride_info_from_integer(int_index):
|
|
611
615
|
"""Convert integer to slice"""
|
|
612
|
-
begin_strides =
|
|
613
|
-
end_strides =
|
|
614
|
-
step_strides =
|
|
615
|
-
|
|
616
|
-
end_tensor = stack(end_strides)
|
|
617
|
-
step_tensor = stack(step_strides)
|
|
618
|
-
return begin_tensor, end_tensor, step_tensor
|
|
616
|
+
begin_strides = (int_index,)
|
|
617
|
+
end_strides = (int_index + 1,)
|
|
618
|
+
step_strides = (1,)
|
|
619
|
+
return begin_strides, end_strides, step_strides
|
|
619
620
|
|
|
620
621
|
|
|
621
622
|
def _tensor_index_by_integer(data, int_index):
|
|
622
623
|
"""Tensor getitem by a single integer number"""
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
tensor_index = _scalar_to_tensor(int_index)
|
|
626
|
-
begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
|
|
627
|
-
else:
|
|
628
|
-
if not data_shape:
|
|
629
|
-
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
630
|
-
if data.ndim < 1 or data.ndim > 8:
|
|
631
|
-
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
632
|
-
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
|
633
|
-
begin_strides, end_strides, step_strides = \
|
|
634
|
-
const_utils.get_stride_info_from_integer(data_shape, transformed_number)
|
|
624
|
+
begin_strides, end_strides, step_strides = get_stride_info_from_integer(int_index)
|
|
625
|
+
|
|
635
626
|
shrink_axis_mask = 1
|
|
636
627
|
begin_mask = 0
|
|
637
628
|
end_mask = 0
|
|
@@ -664,6 +655,7 @@ def tensor_index_by_tensor(data, tensor_index):
|
|
|
664
655
|
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
665
656
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
666
657
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
|
|
658
|
+
tensor_index = F.select(tensor_index < 0, tensor_index + F.shape(data)[0], tensor_index)
|
|
667
659
|
return F.gather(data, tensor_index, 0)
|
|
668
660
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
|
|
669
661
|
return tensor_index_by_bool_tensor(data, tensor_index)
|
|
@@ -676,27 +668,23 @@ def tensor_index_by_tensor(data, tensor_index):
|
|
|
676
668
|
def tensor_index_by_list(data, list_index):
|
|
677
669
|
"""Tensor getitem by list of int and bool"""
|
|
678
670
|
min_data_dim, max_data_dim = 1, 8
|
|
679
|
-
|
|
671
|
+
if F.isconstant(data.ndim):
|
|
672
|
+
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
680
673
|
|
|
681
674
|
data_shape = F.shape(data)
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
list_index, data_shape[0])
|
|
696
|
-
if tensor_index is False:
|
|
697
|
-
const_utils.raise_index_error(
|
|
698
|
-
"When tensor is indexed by list, the list can't be empty.")
|
|
699
|
-
return F.gather(data, tensor_index, 0)
|
|
675
|
+
if F.isconstant(data_shape[0]) and all(isinstance(i, bool) for i in list_index):
|
|
676
|
+
if data_shape[0] != len(list_index):
|
|
677
|
+
raise IndexError(
|
|
678
|
+
f'dimension is {data_shape[0]} but corresponding boolean dimension is {len(list_index)}')
|
|
679
|
+
tensor_index = Tensor(list_index).nonzero()
|
|
680
|
+
return F.gather_nd(data, tensor_index)
|
|
681
|
+
|
|
682
|
+
if not list_index:
|
|
683
|
+
const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
|
|
684
|
+
|
|
685
|
+
index_tensor = _convert_list_index_to_tensor(list_index)
|
|
686
|
+
if index_tensor is not None:
|
|
687
|
+
return tensor_index_by_tensor(data, index_tensor)
|
|
700
688
|
|
|
701
689
|
tuple_index_new = ()
|
|
702
690
|
for index in list_index:
|
|
@@ -704,16 +692,6 @@ def tensor_index_by_list(data, list_index):
|
|
|
704
692
|
return tensor_index_by_tuple(data, tuple_index_new)
|
|
705
693
|
|
|
706
694
|
|
|
707
|
-
def convert_tupleslice_to_tensor(tuple_index):
|
|
708
|
-
"""convert mutable scalar in slice to tensor"""
|
|
709
|
-
new_tuple_index = []
|
|
710
|
-
for item in tuple_index:
|
|
711
|
-
if isinstance(item, slice):
|
|
712
|
-
item = convert_variable_to_tensor_slice(item)
|
|
713
|
-
new_tuple_index.append(item)
|
|
714
|
-
return tuple(new_tuple_index)
|
|
715
|
-
|
|
716
|
-
|
|
717
695
|
def judge_tuple_index_dim_check_error(index_dim, data_dim):
|
|
718
696
|
"""raise IndexError when tuple_index's dim is invalid"""
|
|
719
697
|
if index_dim > data_dim:
|
|
@@ -721,29 +699,6 @@ def judge_tuple_index_dim_check_error(index_dim, data_dim):
|
|
|
721
699
|
f"dim of index:{index_dim}, dim of data:{data_dim}")
|
|
722
700
|
|
|
723
701
|
|
|
724
|
-
class _HandleEmptySlice(base.HandleEmptySlice_):
|
|
725
|
-
"""
|
|
726
|
-
Getting item of Tensor.
|
|
727
|
-
|
|
728
|
-
Args:
|
|
729
|
-
data (Tensor): A tuple to be sliced.
|
|
730
|
-
index: Index of tensor.
|
|
731
|
-
|
|
732
|
-
Returns:
|
|
733
|
-
Type is the same as the element type of data.
|
|
734
|
-
"""
|
|
735
|
-
|
|
736
|
-
def __init__(self, name):
|
|
737
|
-
"""Initialize _HandleEmptySlice."""
|
|
738
|
-
base.HandleEmptySlice_.__init__(self, name)
|
|
739
|
-
|
|
740
|
-
def __call__(self, *args):
|
|
741
|
-
pass
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
_handle_empty_slice = _HandleEmptySlice('handle_zero_tuple_index')
|
|
745
|
-
|
|
746
|
-
|
|
747
702
|
def judge_tuple_index_dim(data, tuple_index):
|
|
748
703
|
"""Judge whether tuple_index's dim is valid"""
|
|
749
704
|
data_dim = data.ndim
|
|
@@ -756,50 +711,20 @@ def judge_tuple_index_dim(data, tuple_index):
|
|
|
756
711
|
judge_tuple_index_dim_check_error(index_dim, data_dim)
|
|
757
712
|
|
|
758
713
|
|
|
759
|
-
def judge_simple_tuple_index(data, tuple_index):
|
|
760
|
-
"""Judge whether tuple_index is simple index, which not rollback to cpu ops."""
|
|
761
|
-
op_name = const_utils.TENSOR_GETITEM
|
|
762
|
-
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
763
|
-
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
764
|
-
return F.isconstant(tuple_index) and contain_type == const_utils.ALL_BASIC \
|
|
765
|
-
and F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(F.rank(data))
|
|
766
|
-
|
|
767
|
-
|
|
768
714
|
def tensor_index_by_tuple(data, tuple_index):
|
|
769
715
|
"""Tensor getitem by tuple of various types with None"""
|
|
770
716
|
if not tuple_index:
|
|
771
717
|
return data
|
|
772
|
-
if judge_simple_tuple_index(data, tuple_index):
|
|
773
|
-
tuple_index = convert_tupleslice_to_tensor(tuple_index)
|
|
774
|
-
op_name = const_utils.TENSOR_GETITEM
|
|
775
|
-
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
776
|
-
min_data_dim, max_data_dim = 1, 8
|
|
777
|
-
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
778
|
-
return _tensor_getitem_by_tuple_slice(data, tuple_index)
|
|
779
718
|
|
|
780
719
|
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
781
720
|
judge_tuple_index_dim(data, tuple_index)
|
|
782
721
|
tuple_index, zero_index, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
783
722
|
for non_zero_shape in non_zero_shapes:
|
|
784
|
-
if
|
|
723
|
+
if 0 in non_zero_shape:
|
|
785
724
|
tuple_index = zero_index
|
|
786
725
|
break
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
if 0 in stub_zero_dim_tensor.shape:
|
|
790
|
-
return F.fill(data.dtype, stub_zero_dim_tensor.shape, 0)
|
|
791
|
-
has_tensor_index = False
|
|
792
|
-
for i in tuple_index:
|
|
793
|
-
if isinstance(i, Tensor):
|
|
794
|
-
has_tensor_index = True
|
|
795
|
-
break
|
|
796
|
-
empty_broadcast_data_shape = False
|
|
797
|
-
_broadcast_data_shape = _handle_scalar_tensor_index(data, tuple_index)
|
|
798
|
-
if has_tensor_index and isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
|
|
799
|
-
empty_broadcast_data_shape = True
|
|
800
|
-
if has_tensor_index and isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
|
|
801
|
-
empty_broadcast_data_shape = True
|
|
802
|
-
return _tensor_index_getitem(data, tuple_index, empty_broadcast_data_shape)
|
|
726
|
+
|
|
727
|
+
return _tensor_index_getitem(data, tuple_index)
|
|
803
728
|
|
|
804
729
|
|
|
805
730
|
def get_slice_stride(slice_index, dim_size):
|
|
@@ -809,20 +734,20 @@ def get_slice_stride(slice_index, dim_size):
|
|
|
809
734
|
step = slice_get_item(slice_index, "step")
|
|
810
735
|
|
|
811
736
|
if start is None:
|
|
812
|
-
start =
|
|
737
|
+
start = 0
|
|
813
738
|
if stop is None:
|
|
814
739
|
stop = dim_size
|
|
815
740
|
if step is None:
|
|
816
|
-
step =
|
|
741
|
+
step = 1
|
|
817
742
|
|
|
818
|
-
if
|
|
819
|
-
start =
|
|
743
|
+
if isinstance(start, Tensor):
|
|
744
|
+
start = int(start)
|
|
820
745
|
|
|
821
|
-
if
|
|
822
|
-
stop =
|
|
746
|
+
if isinstance(stop, Tensor):
|
|
747
|
+
stop = int(stop)
|
|
823
748
|
|
|
824
|
-
if
|
|
825
|
-
step =
|
|
749
|
+
if isinstance(step, Tensor):
|
|
750
|
+
step = int(step)
|
|
826
751
|
|
|
827
752
|
return start, stop, step
|
|
828
753
|
|
|
@@ -841,190 +766,6 @@ def cal_tuple_slice_mask(data_shape, tuple_index):
|
|
|
841
766
|
return begin_mask, end_mask
|
|
842
767
|
|
|
843
768
|
|
|
844
|
-
def _get_stride_info_from_tuple(data, tuple_index):
|
|
845
|
-
"""get the stride info from tuple"""
|
|
846
|
-
data_shape = F.dyn_shape(data)
|
|
847
|
-
begin_strides, end_strides, step_strides = [], [], []
|
|
848
|
-
tuple_index_len = len(tuple_index)
|
|
849
|
-
data_dim = data.ndim
|
|
850
|
-
shrink_axis, index_count, ellipsis_count = 0, 0, 0
|
|
851
|
-
for item in range(data_dim):
|
|
852
|
-
if item >= tuple_index_len or item >= data_dim:
|
|
853
|
-
break
|
|
854
|
-
index = tuple_index[item]
|
|
855
|
-
dim_size = data_shape[item]
|
|
856
|
-
if isinstance(index, slice):
|
|
857
|
-
start, stop, step = get_slice_stride(index, dim_size)
|
|
858
|
-
begin_strides.append(start)
|
|
859
|
-
end_strides.append(stop)
|
|
860
|
-
step_strides.append(step)
|
|
861
|
-
index_count = index_count + 1
|
|
862
|
-
elif isinstance(index, int):
|
|
863
|
-
int_tensor = _scalar_to_tensor(index)
|
|
864
|
-
begin_strides.append(int_tensor)
|
|
865
|
-
end_strides.append(int_tensor + const_utils.make_tensor(1))
|
|
866
|
-
step_strides.append(const_utils.make_tensor(1))
|
|
867
|
-
shrink_axis = shrink_axis + (2 ** index_count)
|
|
868
|
-
index_count = index_count + 1
|
|
869
|
-
elif index is ...:
|
|
870
|
-
ellipsis_count = ellipsis_count + 1
|
|
871
|
-
if ellipsis_count > 1:
|
|
872
|
-
const_utils.raise_value_error("An index can have only one ellipsis (...)")
|
|
873
|
-
ellipsis_range_size = data_dim - tuple_index_len + 1
|
|
874
|
-
begin_strides.extend([const_utils.make_tensor(0)] * ellipsis_range_size)
|
|
875
|
-
end_strides.extend(
|
|
876
|
-
[shape for shape in data_shape[index_count: index_count + ellipsis_range_size]])
|
|
877
|
-
step_strides.extend([const_utils.make_tensor(1)] * ellipsis_range_size)
|
|
878
|
-
index_count = index_count + ellipsis_range_size
|
|
879
|
-
else:
|
|
880
|
-
exp_msg = const_utils.gen_exception_msg("Not supported index data type, got {}, type is {}", index,
|
|
881
|
-
type(index))
|
|
882
|
-
const_utils.raise_index_error(exp_msg)
|
|
883
|
-
begin_tensor = stack(begin_strides)
|
|
884
|
-
end_tensor = stack(end_strides)
|
|
885
|
-
step_tensor = stack(step_strides)
|
|
886
|
-
strides_v = {
|
|
887
|
-
'begin': begin_tensor,
|
|
888
|
-
'end': end_tensor,
|
|
889
|
-
'step': step_tensor
|
|
890
|
-
}
|
|
891
|
-
return strides_v, shrink_axis
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|
895
|
-
"""Tensor getitem by a tuple of slice"""
|
|
896
|
-
data_shape = F.shape(data)
|
|
897
|
-
is_dynamic = F.is_sequence_value_unknown(data_shape)
|
|
898
|
-
for item in tuple_index:
|
|
899
|
-
if isinstance(item, slice):
|
|
900
|
-
is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
|
|
901
|
-
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
|
902
|
-
or isinstance(slice_get_item(item, "step"), Tensor)
|
|
903
|
-
|
|
904
|
-
strides_v = {}
|
|
905
|
-
shrink_axis_mask = 0
|
|
906
|
-
if not is_dynamic:
|
|
907
|
-
strides_v, shrink_axis_mask = const_utils.get_stride_info_from_tuple(
|
|
908
|
-
data_shape, tuple_index)
|
|
909
|
-
else:
|
|
910
|
-
strides_v, shrink_axis_mask = _get_stride_info_from_tuple(
|
|
911
|
-
data, tuple_index)
|
|
912
|
-
begin_mask, end_mask = cal_tuple_slice_mask(data_shape, tuple_index)
|
|
913
|
-
begin_v = strides_v['begin']
|
|
914
|
-
end_v = strides_v['end']
|
|
915
|
-
step_v = strides_v['step']
|
|
916
|
-
return strided_slice(data, begin_v, end_v, step_v, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
@_primexpr
|
|
920
|
-
def _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
921
|
-
tensor_positions_new):
|
|
922
|
-
""" parse index of bool tensor type """
|
|
923
|
-
indices = index.nonzero()
|
|
924
|
-
if indices.shape[0] == 0:
|
|
925
|
-
return None, tensor_indexes, tensor_positions_new
|
|
926
|
-
indices = F.cast(indices, mstype.int64)
|
|
927
|
-
indices = indices.T
|
|
928
|
-
for sub_index in indices:
|
|
929
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
930
|
-
tuple_index_new += (sub_index,)
|
|
931
|
-
tensor_indexes.append(sub_index)
|
|
932
|
-
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
def _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new, tensor_indexes, tensor_positions_new):
|
|
936
|
-
""" parse index of tensor type """
|
|
937
|
-
if F.dtype(index) in mstype.int_type:
|
|
938
|
-
tensor_index = F.cast(index, mstype.int64)
|
|
939
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
940
|
-
tuple_index_new += (tensor_index,)
|
|
941
|
-
tensor_indexes.append(tensor_index)
|
|
942
|
-
elif F.dtype(index) == mstype.bool_:
|
|
943
|
-
return _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
944
|
-
tensor_positions_new)
|
|
945
|
-
else:
|
|
946
|
-
exp_msg = const_utils.gen_exception_msg(
|
|
947
|
-
"The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
|
|
948
|
-
const_utils.raise_index_error(exp_msg)
|
|
949
|
-
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|
953
|
-
"""Tensor getitem by a tuple of mixed tensor."""
|
|
954
|
-
slice_is_tensor = False
|
|
955
|
-
for item in tuple_index:
|
|
956
|
-
if isinstance(item, slice):
|
|
957
|
-
slice_is_tensor = isinstance(slice_get_item(item, "start"), Tensor) \
|
|
958
|
-
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
|
959
|
-
or isinstance(slice_get_item(item, "step"), Tensor)
|
|
960
|
-
if slice_is_tensor:
|
|
961
|
-
const_utils.raise_index_error("Not supported when slice has tensor")
|
|
962
|
-
|
|
963
|
-
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
964
|
-
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
|
|
965
|
-
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
|
966
|
-
data_shape = F.shape(data)
|
|
967
|
-
tensor_indexes, slice_indexes = [], []
|
|
968
|
-
tuple_index_new, slice_shapes = (), ()
|
|
969
|
-
slice_positions_new, tensor_positions_new = [], []
|
|
970
|
-
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
|
|
971
|
-
if i in int_positions:
|
|
972
|
-
int_index = const_utils.check_range(index, dim_size)
|
|
973
|
-
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
|
|
974
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
975
|
-
tensor_index = _scalar_to_tensor(int_index)
|
|
976
|
-
tensor_index = F.cast(tensor_index, mstype.int64)
|
|
977
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
978
|
-
tuple_index_new += (tensor_index,)
|
|
979
|
-
tensor_indexes.append(tensor_index)
|
|
980
|
-
elif i in sequence_positions:
|
|
981
|
-
tensor_index = const_utils.sequence_to_index(index, dim_size)
|
|
982
|
-
if tensor_index is False:
|
|
983
|
-
const_utils.raise_index_error("The sequence element(tuple/list) in tuple index can't be empty.")
|
|
984
|
-
tensor_positions_new.append(len(tuple_index_new))
|
|
985
|
-
tuple_index_new += (tensor_index,)
|
|
986
|
-
tensor_indexes.append(tensor_index)
|
|
987
|
-
elif i in tensor_positions:
|
|
988
|
-
tuple_index_new, tensor_indexes, tensor_positions_new = \
|
|
989
|
-
_tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new,
|
|
990
|
-
tensor_indexes, tensor_positions_new)
|
|
991
|
-
if tuple_index_new is None:
|
|
992
|
-
return Tensor([])
|
|
993
|
-
elif i in slice_positions:
|
|
994
|
-
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
|
|
995
|
-
slice_shapes += (len(slice_ele_list_index),)
|
|
996
|
-
slice_positions_new.append(len(tuple_index_new))
|
|
997
|
-
tuple_index_new += (slice_ele_list_index,)
|
|
998
|
-
slice_indexes.append(slice_ele_list_index)
|
|
999
|
-
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
|
1000
|
-
broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
|
|
1001
|
-
const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions_new, tensor_indexes_shapes,
|
|
1002
|
-
slice_shapes, op_name)
|
|
1003
|
-
|
|
1004
|
-
tuple_index_len = len(tuple_index)
|
|
1005
|
-
if 0 in final_shape + data_shape:
|
|
1006
|
-
if tuple_index_len < len(data_shape):
|
|
1007
|
-
final_shape = final_shape + data_shape[tuple_index_len:]
|
|
1008
|
-
return const_utils.make_tensor([], data.dtype, final_shape)
|
|
1009
|
-
|
|
1010
|
-
final_index_tensors = []
|
|
1011
|
-
slice_cnt = 0
|
|
1012
|
-
for i, index in enumerate(tuple_index_new):
|
|
1013
|
-
if i in tensor_positions_new:
|
|
1014
|
-
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
|
|
1015
|
-
index)
|
|
1016
|
-
final_index_tensors.append(transform_tensor)
|
|
1017
|
-
elif i in slice_positions_new:
|
|
1018
|
-
slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
|
|
1019
|
-
slice_shapes, fancy_position)
|
|
1020
|
-
final_index_tensors.append(slice_index_tensor)
|
|
1021
|
-
slice_cnt += 1
|
|
1022
|
-
|
|
1023
|
-
indices = stack(final_index_tensors)
|
|
1024
|
-
result = F.gather_nd(data, indices)
|
|
1025
|
-
return result
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
769
|
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
|
|
1029
770
|
"""Generate an indices tensor from a tuple of tensor."""
|
|
1030
771
|
indexes_types = hyper_map(F.dtype, tuple_index)
|
|
@@ -1116,8 +857,15 @@ def sequence_to_tensor(value, dtype):
|
|
|
1116
857
|
|
|
1117
858
|
if value_elements_type == const_utils.ALL_TENSOR:
|
|
1118
859
|
value = F.stack(value).astype(dtype)
|
|
1119
|
-
elif value_elements_type == const_utils.NO_TENSOR
|
|
1120
|
-
|
|
860
|
+
elif value_elements_type == const_utils.NO_TENSOR:
|
|
861
|
+
if isinstance(value, list):
|
|
862
|
+
value = tuple(value)
|
|
863
|
+
|
|
864
|
+
if dtype == mstype.float16:
|
|
865
|
+
value = TupleToTensor()(value, mstype.float32)
|
|
866
|
+
value = F.cast(value, dtype)
|
|
867
|
+
else:
|
|
868
|
+
value = TupleToTensor()(value, dtype)
|
|
1121
869
|
else:
|
|
1122
870
|
new_value = ()
|
|
1123
871
|
for ele in value:
|
|
@@ -1138,57 +886,31 @@ def _generate_updates_from_sequence(data, index, value, op_type):
|
|
|
1138
886
|
def _generate_updates_from_tensor(data, index, value, op_type):
|
|
1139
887
|
"""Generate an updates tensor from a tensor."""
|
|
1140
888
|
value = value.astype(data.dtype)
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
|
1145
|
-
updates = ops.broadcast_to(value, updates_shape)
|
|
1146
|
-
return updates
|
|
1147
|
-
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type, False)
|
|
1148
|
-
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
|
|
1149
|
-
if need_broadcast:
|
|
1150
|
-
return _broadcast(updates_shape, value)
|
|
1151
|
-
return value
|
|
889
|
+
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type)
|
|
890
|
+
updates = ops.broadcast_to(value, updates_shape)
|
|
891
|
+
return updates
|
|
1152
892
|
|
|
1153
893
|
|
|
1154
894
|
# Tensor getitem implementations are above this line, setitem implementations below.
|
|
1155
895
|
|
|
1156
|
-
def
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
return tensor_setitem_by_tensor_with_tensor(self, index, value)
|
|
1161
|
-
return tensor_setitem_by_tensor_with_sequence(self, index, value)
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
def tensor_setitem_by_tuple(self, index, value):
|
|
1165
|
-
index = convert_tupleslice_to_tensor(index)
|
|
1166
|
-
if isinstance(value, (int, float, bool)):
|
|
1167
|
-
index = format_tuple_indices(index)
|
|
1168
|
-
return tensor_setitem_by_tuple_with_number(self, index, value)
|
|
1169
|
-
if isinstance(value, Tensor):
|
|
1170
|
-
return tensor_setitem_by_tuple_with_tensor(self, index, value)
|
|
1171
|
-
return tensor_setitem_by_tuple_with_sequence(self, index, value)
|
|
896
|
+
def _tensor_index_transfer(index, broadcast_shape, final_shape, new_shape):
|
|
897
|
+
"""Transform tuple index tensor to the required."""
|
|
898
|
+
if 0 in final_shape:
|
|
899
|
+
return F.fill(index.dtype, final_shape, 0)
|
|
1172
900
|
|
|
901
|
+
if broadcast_shape == ():
|
|
902
|
+
# broadcast_to () is not support on Ascend
|
|
903
|
+
item = index
|
|
904
|
+
else:
|
|
905
|
+
item = F.broadcast_to(index, broadcast_shape)
|
|
906
|
+
item = F.reshape(item, new_shape)
|
|
907
|
+
return F.broadcast_to(item, final_shape)
|
|
1173
908
|
|
|
1174
|
-
def tensor_setitem_by_number(self, index, value):
|
|
1175
|
-
if isinstance(value, (int, float, bool)):
|
|
1176
|
-
return tensor_setitem_by_number_with_number(self, index, value)
|
|
1177
|
-
if isinstance(value, Tensor):
|
|
1178
|
-
return tensor_setitem_by_number_with_tensor(self, index, value)
|
|
1179
|
-
return tensor_setitem_by_number_with_sequence(self, index, value)
|
|
1180
909
|
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
if not all_empty_tensor:
|
|
1186
|
-
x = F.broadcast_to(x, broadcast_shape)
|
|
1187
|
-
x = F.reshape(x, new_shape)
|
|
1188
|
-
x = F.broadcast_to(x, final_shape)
|
|
1189
|
-
return x
|
|
1190
|
-
item = _broadcast(broadcast_shape, x)
|
|
1191
|
-
return _broadcast(final_shape, F.reshape(item, new_shape))
|
|
910
|
+
def reshape_with_check(x, new_shape):
|
|
911
|
+
if isinstance(new_shape, Tensor):
|
|
912
|
+
new_shape = TensorToTuple()(new_shape)
|
|
913
|
+
return F.reshape(x, new_shape)
|
|
1192
914
|
|
|
1193
915
|
|
|
1194
916
|
class _TensorIndexSetitem(base.TensorIndexSetitem_):
|
|
@@ -1218,9 +940,10 @@ def tensor_setitem_by_slice(self, index, value):
|
|
|
1218
940
|
return self
|
|
1219
941
|
value = F.broadcast_to(value, value_shape)
|
|
1220
942
|
if not const_utils.is_ascend() and step == 1:
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
943
|
+
start = (start,)
|
|
944
|
+
stop = (stop,)
|
|
945
|
+
step = (step,)
|
|
946
|
+
return copy_slice(self, value, start, stop, step)
|
|
1224
947
|
return F.tensor_scatter_update(self, indices, value)
|
|
1225
948
|
|
|
1226
949
|
|
|
@@ -1236,14 +959,14 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
|
|
1236
959
|
"""Set a tensor item by an int tensor with a tensor."""
|
|
1237
960
|
if F.rank(index) == 0:
|
|
1238
961
|
index = F.expand_dims(index, -1)
|
|
1239
|
-
|
|
962
|
+
|
|
1240
963
|
data_shape = F.shape(data)
|
|
964
|
+
updates_shape = index.shape + data_shape[1:]
|
|
965
|
+
value = F.cast(value, F.dtype(data))
|
|
966
|
+
updates = ops.broadcast_to(value, updates_shape)
|
|
1241
967
|
first_val = data_shape[0]
|
|
1242
968
|
index = F.select(index < 0, index + first_val, index)
|
|
1243
969
|
index = F.expand_dims(index, -1)
|
|
1244
|
-
if F.rank(index) < 2:
|
|
1245
|
-
index = F.expand_dims(index, 0)
|
|
1246
|
-
updates = F.expand_dims(updates, 0)
|
|
1247
970
|
if is_parameter(data):
|
|
1248
971
|
F.scatter_nd_update(data, index, updates)
|
|
1249
972
|
return data
|
|
@@ -1255,8 +978,7 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
|
|
1255
978
|
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
|
1256
979
|
index = F.broadcast_to(index, data.shape)
|
|
1257
980
|
value = F.cast(value, F.dtype(data))
|
|
1258
|
-
|
|
1259
|
-
value = value.unsqueeze(-1)
|
|
981
|
+
value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
|
|
1260
982
|
value = F.broadcast_to(value, data.shape)
|
|
1261
983
|
result = F.select(index, value, data)
|
|
1262
984
|
return result
|
|
@@ -1269,8 +991,6 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|
|
1269
991
|
if tensor_dtype == const_utils.INT_:
|
|
1270
992
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
|
1271
993
|
|
|
1272
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1273
|
-
return tensor_setitem_by_tuple_with_tensor(data, (index,), value_tensor.astype(data.dtype))
|
|
1274
994
|
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
|
1275
995
|
|
|
1276
996
|
|
|
@@ -1281,33 +1001,8 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
|
|
|
1281
1001
|
|
|
1282
1002
|
def tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
1283
1003
|
"""Assigns the tensor by tensor with tuple value."""
|
|
1284
|
-
index_dtype = F.dtype(index)
|
|
1285
|
-
if index_dtype in (mstype.int32, mstype.int64):
|
|
1286
|
-
return _tensor_setitem_by_tensor_with_sequence(data, index, value)
|
|
1287
|
-
if index_dtype == mstype.bool_:
|
|
1288
|
-
return _tensor_setitem_by_bool_tensor_with_sequence(data, index, value)
|
|
1289
|
-
exp_msg = const_utils.gen_exception_msg("The tensor index must be int or bool type, but got {}.", index_dtype)
|
|
1290
|
-
const_utils.raise_index_error(exp_msg)
|
|
1291
|
-
return None
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
def _tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
1295
|
-
"""Set a tensor item by a tensor with a tuple."""
|
|
1296
|
-
updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
|
1297
|
-
index = F.expand_dims(index, -1)
|
|
1298
|
-
return F.tensor_scatter_update(data, index, updates)
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
|
|
1302
|
-
"""Set a tensor item by a bool tensor with a tuple."""
|
|
1303
1004
|
value = sequence_to_tensor(value, F.dtype(data))
|
|
1304
|
-
return
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
|
1308
|
-
"""Givens a scalar assign to tensor by slice"""
|
|
1309
|
-
value = F.cast(value, F.dtype(data))
|
|
1310
|
-
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
|
1005
|
+
return tensor_setitem_by_tensor_with_tensor(data, index, value)
|
|
1311
1006
|
|
|
1312
1007
|
|
|
1313
1008
|
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|
@@ -1316,78 +1011,14 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|
|
1316
1011
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
|
1317
1012
|
|
|
1318
1013
|
|
|
1319
|
-
def
|
|
1320
|
-
"""
|
|
1321
|
-
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
stop_tensor = stack((stop,))
|
|
1325
|
-
step_tensor = stack((step,))
|
|
1326
|
-
dim0_size = stop_tensor - start_tensor
|
|
1327
|
-
if dim0_size <= 0:
|
|
1328
|
-
return data
|
|
1329
|
-
if dim0_size >= data_shape[0]:
|
|
1330
|
-
dim0_size = data_shape[0:1]
|
|
1331
|
-
value_shape = P.Concat(-1)((dim0_size, data_shape[1:]))
|
|
1332
|
-
value = ops.broadcast_to(value, value_shape)
|
|
1333
|
-
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
|
1337
|
-
"""Assigns a tensor value to the tensor by slice."""
|
|
1338
|
-
result = None
|
|
1339
|
-
check_result = const_utils.check_tensor_setitem_index(input_slice)
|
|
1340
|
-
if check_result:
|
|
1341
|
-
data_shape = F.shape(data)
|
|
1342
|
-
step = const_utils.get_step_from_slice(input_slice)
|
|
1343
|
-
if step == 1 and not const_utils.is_ascend():
|
|
1344
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1345
|
-
return tensor_copy_slice_from_slice(data, input_slice, value)
|
|
1346
|
-
start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
|
|
1347
|
-
dim0_size = stop - start
|
|
1348
|
-
if dim0_size <= 0:
|
|
1349
|
-
return data
|
|
1350
|
-
value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
|
|
1351
|
-
value = _broadcast(value_shape, value)
|
|
1352
|
-
return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
|
|
1353
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1354
|
-
const_utils.raise_unimplemented_error(
|
|
1355
|
-
"Not supported to take the subscript of dynamic shape tensor slice setitem")
|
|
1356
|
-
indices = const_utils.slice2indices(input_slice, data_shape)
|
|
1357
|
-
if indices is False:
|
|
1358
|
-
return data
|
|
1359
|
-
value_shape = const_utils.tuple_slice(F.shape(indices), None, -1)
|
|
1360
|
-
value = _broadcast(value_shape, value)
|
|
1361
|
-
result = F.tensor_scatter_update(data, indices, value.astype(F.dtype(data)))
|
|
1362
|
-
return result
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
def tensor_setitem_by_slice_with_sequence(data, input_slice, value):
|
|
1366
|
-
"""Assigns a list/tuple value to the tensor by slice."""
|
|
1367
|
-
value = _generate_updates_from_sequence(data, input_slice, value, const_utils.SET_ITEM_BY_NON_TENSOR)
|
|
1368
|
-
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
|
1014
|
+
def tensor_setitem_by_list(data, index, value):
|
|
1015
|
+
"""list indices will be converted to tuple or tensor based on its contents."""
|
|
1016
|
+
index_tensor = _convert_list_index_to_tensor(index)
|
|
1017
|
+
if index_tensor is not None:
|
|
1018
|
+
return tensor_setitem_by_tensor_with_tensor(data, index_tensor, value)
|
|
1369
1019
|
|
|
1020
|
+
return tensor_setitem_by_tuple_with_tensor(data, tuple(index), value)
|
|
1370
1021
|
|
|
1371
|
-
def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
1372
|
-
"""using TensorCopySlices by fixed model tuple."""
|
|
1373
|
-
data_shape = F.dyn_shape(data)
|
|
1374
|
-
dim1_start, dim1_stop, _ = get_slice_stride(tuple_index[1], data_shape[1])
|
|
1375
|
-
if dim1_stop - dim1_start <= 0:
|
|
1376
|
-
return data
|
|
1377
|
-
dim0_start = _scalar_to_tensor(tuple_index[0])
|
|
1378
|
-
dim0_stop = dim0_start + const_utils.make_tensor(1)
|
|
1379
|
-
start = (dim0_start, dim1_start)
|
|
1380
|
-
stop = (dim0_stop, dim1_stop)
|
|
1381
|
-
step = (const_utils.make_tensor(1), const_utils.make_tensor(1))
|
|
1382
|
-
start_tensor = stack(start)
|
|
1383
|
-
stop_tensor = stack(stop)
|
|
1384
|
-
step_tensor = stack(step)
|
|
1385
|
-
dim1_size = stack((dim1_stop - dim1_start,))
|
|
1386
|
-
if dim1_size > data_shape[1]:
|
|
1387
|
-
dim1_size = data_shape[1:2]
|
|
1388
|
-
value_shape = P.Concat(-1)((dim1_size, data_shape[2:]))
|
|
1389
|
-
value = ops.broadcast_to(value, value_shape)
|
|
1390
|
-
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
1391
1022
|
|
|
1392
1023
|
|
|
1393
1024
|
class _PreSetitemByTuple(base.PreSetitemByTuple_):
|
|
@@ -1436,50 +1067,28 @@ class _HandleBoolTensor(base.HandleBoolTensor_):
|
|
|
1436
1067
|
_handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
|
|
1437
1068
|
|
|
1438
1069
|
|
|
1439
|
-
class _HandleScalarTensorIndex(base.HandleScalarTensorIndex_):
|
|
1440
|
-
"""
|
|
1441
|
-
Getting item of Tensor.
|
|
1442
|
-
|
|
1443
|
-
Args:
|
|
1444
|
-
data (Tensor): A tuple to be sliced.
|
|
1445
|
-
index: Index of tensor.
|
|
1446
|
-
|
|
1447
|
-
Returns:
|
|
1448
|
-
Type is the same as the element type of data.
|
|
1449
|
-
"""
|
|
1450
|
-
|
|
1451
|
-
def __init__(self, name):
|
|
1452
|
-
"""Initialize _HandleBoolTensor."""
|
|
1453
|
-
base.HandleScalarTensorIndex_.__init__(self, name)
|
|
1454
|
-
|
|
1455
|
-
def __call__(self, *args):
|
|
1456
|
-
pass
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
_handle_scalar_tensor_index = _HandleScalarTensorIndex('handle_scalar_tensor_index')
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
1070
|
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
1463
1071
|
"""Assigns the tensor by tuple with tensor value."""
|
|
1464
1072
|
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1465
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1466
|
-
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1467
1073
|
dim1_start, dim1_stop, _ = const_utils.normalize_slice(
|
|
1468
1074
|
tuple_index[1], data.shape[1])
|
|
1075
|
+
if isinstance(dim1_start, Tensor):
|
|
1076
|
+
dim1_start = int(dim1_start)
|
|
1077
|
+
if isinstance(dim1_stop, Tensor):
|
|
1078
|
+
dim1_stop = int(dim1_stop)
|
|
1469
1079
|
if dim1_stop - dim1_start <= 0:
|
|
1470
1080
|
return data
|
|
1471
1081
|
dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
|
|
1472
1082
|
start = (dim0_start, dim1_start)
|
|
1473
1083
|
stop = (dim0_start + 1, dim1_stop)
|
|
1474
1084
|
step = (1, 1)
|
|
1475
|
-
value_shape = (dim1_stop - dim1_start,) +
|
|
1476
|
-
|
|
1477
|
-
value = _broadcast(value_shape, value)
|
|
1085
|
+
value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
|
|
1086
|
+
value = F.broadcast_to(value, value_shape)
|
|
1478
1087
|
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1479
1088
|
tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
1480
1089
|
|
|
1481
1090
|
for non_zero_shape in non_zero_shapes:
|
|
1482
|
-
if
|
|
1091
|
+
if 0 in non_zero_shape:
|
|
1483
1092
|
return data
|
|
1484
1093
|
value = value.astype(data.dtype)
|
|
1485
1094
|
special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
|
|
@@ -1512,17 +1121,19 @@ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1512
1121
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
1513
1122
|
|
|
1514
1123
|
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1515
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1516
|
-
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1517
1124
|
dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
|
|
1125
|
+
if isinstance(dim1_start, Tensor):
|
|
1126
|
+
dim1_start = int(dim1_start)
|
|
1127
|
+
if isinstance(dim1_stop, Tensor):
|
|
1128
|
+
dim1_stop = int(dim1_stop)
|
|
1518
1129
|
if dim1_stop - dim1_start <= 0:
|
|
1519
1130
|
return data
|
|
1520
1131
|
dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
|
|
1521
1132
|
start = (dim0_start, dim1_start)
|
|
1522
1133
|
stop = (dim0_start + 1, dim1_stop)
|
|
1523
1134
|
step = (1, 1)
|
|
1524
|
-
value_shape = (dim1_stop - dim1_start,) +
|
|
1525
|
-
value =
|
|
1135
|
+
value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
|
|
1136
|
+
value = F.broadcast_to(value, value_shape)
|
|
1526
1137
|
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1527
1138
|
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
|
|
1528
1139
|
|
|
@@ -1545,49 +1156,45 @@ def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1545
1156
|
|
|
1546
1157
|
|
|
1547
1158
|
def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
|
|
1548
|
-
value =
|
|
1159
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1549
1160
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
|
1550
1161
|
|
|
1551
1162
|
|
|
1552
1163
|
def tensor_setitem_by_number_with_number(data, index, value):
|
|
1553
1164
|
"""Assigns the tensor by number with number value."""
|
|
1554
|
-
|
|
1555
|
-
|
|
1165
|
+
data_shape = F.shape(data)
|
|
1166
|
+
dim_size = data_shape[0]
|
|
1167
|
+
if index < 0:
|
|
1168
|
+
index += dim_size
|
|
1169
|
+
if index < -dim_size or index >= dim_size:
|
|
1170
|
+
raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
|
|
1171
|
+
index = F.cast(index, mstype.int64)
|
|
1172
|
+
index = F.reshape(index, (1, 1))
|
|
1173
|
+
|
|
1174
|
+
updates = F.cast(value, data.dtype)
|
|
1175
|
+
updates_shape = (1,) + data_shape[1:]
|
|
1176
|
+
updates = ops.broadcast_to(updates, updates_shape)
|
|
1177
|
+
|
|
1178
|
+
if is_parameter(data):
|
|
1179
|
+
F.scatter_nd_update(data, index, updates)
|
|
1180
|
+
return data
|
|
1181
|
+
return F.tensor_scatter_update(data, index, updates)
|
|
1556
1182
|
|
|
1557
1183
|
|
|
1558
1184
|
def tensor_setitem_by_number_with_sequence(data, index, value):
|
|
1559
1185
|
"""Assigns a list/tuple value to the tensor by slice."""
|
|
1560
|
-
value =
|
|
1186
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1561
1187
|
return tensor_setitem_by_number_with_tensor(data, index, value)
|
|
1562
1188
|
|
|
1563
1189
|
|
|
1564
1190
|
def tensor_setitem_by_number_with_tensor(data, index, value):
|
|
1565
|
-
|
|
1566
|
-
data_shape = F.shape(data)
|
|
1567
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1568
|
-
index = _scalar_to_tensor(index)
|
|
1569
|
-
index = F.expand_dims(index, -1)
|
|
1570
|
-
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value)
|
|
1571
|
-
|
|
1572
|
-
dim_size = data_shape[0]
|
|
1573
|
-
if index < -dim_size or index >= dim_size:
|
|
1574
|
-
raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
|
|
1575
|
-
index = const_utils.int_to_index(index, data_shape)
|
|
1576
|
-
value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
|
|
1577
|
-
value = _broadcast(value_shape, value.astype(F.dtype(data)))
|
|
1578
|
-
if is_parameter(data):
|
|
1579
|
-
F.scatter_nd_update(data, index, value)
|
|
1580
|
-
return data
|
|
1581
|
-
return F.tensor_scatter_update(data, index, value)
|
|
1191
|
+
return tensor_setitem_by_number_with_number(data, index, value)
|
|
1582
1192
|
|
|
1583
1193
|
|
|
1584
1194
|
def tensor_setitem_by_ellipsis_with_number(data, value):
|
|
1585
1195
|
"""Assigns the tensor by ellipsis with number value."""
|
|
1586
1196
|
data_shape = F.shape(data)
|
|
1587
1197
|
data_dtype = F.dtype(data)
|
|
1588
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1589
|
-
value = F.cast(value, F.dtype(data))
|
|
1590
|
-
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1591
1198
|
return F.fill(data_dtype, data_shape, value)
|
|
1592
1199
|
|
|
1593
1200
|
|
|
@@ -1597,21 +1204,20 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
|
|
|
1597
1204
|
data_dtype = F.dtype(data)
|
|
1598
1205
|
value = value.astype(data_dtype)
|
|
1599
1206
|
|
|
1600
|
-
if F.is_sequence_value_unknown(data_shape):
|
|
1601
|
-
data_shape = F.dyn_shape(data)
|
|
1602
|
-
data = ops.broadcast_to(value, data_shape)
|
|
1603
|
-
return data
|
|
1604
1207
|
value_shape = F.shape(value)
|
|
1605
|
-
|
|
1208
|
+
|
|
1209
|
+
if len(value_shape) > len(data_shape):
|
|
1210
|
+
source_shape = data_shape
|
|
1211
|
+
else:
|
|
1212
|
+
source_shape = value_shape
|
|
1606
1213
|
value = F.reshape(value, source_shape)
|
|
1607
|
-
|
|
1608
|
-
data = F.cast(value, data_dtype)
|
|
1214
|
+
data = F.broadcast_to(value, data_shape)
|
|
1609
1215
|
return data
|
|
1610
1216
|
|
|
1611
1217
|
|
|
1612
1218
|
def tensor_setitem_by_ellipsis_with_sequence(data, value):
|
|
1613
1219
|
"""Assigns a list/tuple value to the tensor by ellipsis."""
|
|
1614
|
-
value =
|
|
1220
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1615
1221
|
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1616
1222
|
|
|
1617
1223
|
|
|
@@ -1622,23 +1228,18 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1622
1228
|
if not index:
|
|
1623
1229
|
data_shape = (0,) + data_shape
|
|
1624
1230
|
if isinstance(value, (list, tuple)):
|
|
1625
|
-
value =
|
|
1626
|
-
|
|
1627
|
-
value =
|
|
1628
|
-
|
|
1629
|
-
value = const_utils.make_tensor(value, mstype.float32)
|
|
1630
|
-
|
|
1631
|
-
if F.is_sequence_value_unknown(data_shape) and index:
|
|
1632
|
-
data_shape = F.dyn_shape(data)
|
|
1633
|
-
value = value.astype(data_dtype)
|
|
1634
|
-
data = ops.broadcast_to(value, data_shape)
|
|
1635
|
-
return data
|
|
1636
|
-
value_shape = F.shape(value)
|
|
1637
|
-
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
|
1231
|
+
value = sequence_to_tensor(value, data_dtype)
|
|
1232
|
+
else:
|
|
1233
|
+
value = F.cast(value, data_dtype)
|
|
1234
|
+
|
|
1638
1235
|
if index:
|
|
1236
|
+
value_shape = F.shape(value)
|
|
1237
|
+
if len(value_shape) > len(data_shape):
|
|
1238
|
+
source_shape = data_shape
|
|
1239
|
+
else:
|
|
1240
|
+
source_shape = value_shape
|
|
1639
1241
|
value = F.reshape(value, source_shape)
|
|
1640
|
-
|
|
1641
|
-
data = F.cast(value, data_dtype)
|
|
1242
|
+
data = F.broadcast_to(value, data_shape)
|
|
1642
1243
|
return data
|
|
1643
1244
|
|
|
1644
1245
|
|
|
@@ -1651,33 +1252,6 @@ def tensor_in_sequence(x, y):
|
|
|
1651
1252
|
return result
|
|
1652
1253
|
|
|
1653
1254
|
|
|
1654
|
-
def format_list_indices(list_indices, length):
|
|
1655
|
-
"""Convert list indices to tensor or tuple indices based on its contents."""
|
|
1656
|
-
indices_types = hyper_map(F.typeof, list_indices)
|
|
1657
|
-
# If eyery element in list is bool, it's treated as 1-D bool tensor.
|
|
1658
|
-
# If every element in list is int(not all bool), it's treated as int tensor.
|
|
1659
|
-
if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
|
|
1660
|
-
if not F.isconstant(length):
|
|
1661
|
-
return const_utils.sequence_to_index(list_indices, None)
|
|
1662
|
-
return const_utils.sequence_to_index(list_indices, length)
|
|
1663
|
-
# If list contains other types(.../list/tuple/None), it's treated as a tuple
|
|
1664
|
-
return const_utils.deep_tuple(list_indices)
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
def format_tuple_indices(tuple_indices):
|
|
1668
|
-
"""
|
|
1669
|
-
Format tuple indices by unpacking high-dimension tuple and removing expand
|
|
1670
|
-
dimension signs(Bool and None).
|
|
1671
|
-
"""
|
|
1672
|
-
res = ()
|
|
1673
|
-
for i in tuple_indices:
|
|
1674
|
-
if isinstance(i, (list, tuple)):
|
|
1675
|
-
res += (const_utils.unpack(i),)
|
|
1676
|
-
else:
|
|
1677
|
-
res += (i,)
|
|
1678
|
-
return res
|
|
1679
|
-
|
|
1680
|
-
|
|
1681
1255
|
@_primexpr
|
|
1682
1256
|
def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
|
|
1683
1257
|
""" Parse bool tensor index """
|
|
@@ -1830,7 +1404,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
|
|
|
1830
1404
|
return reduce_fn(a, axes).astype(dtype)
|
|
1831
1405
|
|
|
1832
1406
|
|
|
1833
|
-
tensor_operator_registry
|
|
1407
|
+
setattr(tensor_operator_registry, "reduce", reduce_)
|
|
1834
1408
|
|
|
1835
1409
|
|
|
1836
1410
|
def check_indices(dims, indices, mode, allow_negative_index=True):
|
|
@@ -1857,7 +1431,7 @@ def check_indices(dims, indices, mode, allow_negative_index=True):
|
|
|
1857
1431
|
return clipped
|
|
1858
1432
|
|
|
1859
1433
|
|
|
1860
|
-
tensor_operator_registry
|
|
1434
|
+
setattr(tensor_operator_registry, 'check_indices', check_indices)
|
|
1861
1435
|
|
|
1862
1436
|
|
|
1863
1437
|
def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
|