mindspore 1.10.0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,24 +15,32 @@
|
|
|
15
15
|
|
|
16
16
|
"""constexpr util"""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
|
+
from enum import IntEnum
|
|
19
|
+
|
|
18
20
|
|
|
19
21
|
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
20
22
|
from mindspore.ops import functional as F
|
|
21
|
-
from mindspore.ops
|
|
23
|
+
from mindspore.ops import operations as P
|
|
22
24
|
from mindspore.ops.composite import base
|
|
23
25
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
24
|
-
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem,
|
|
26
|
+
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
|
|
27
|
+
TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo
|
|
25
28
|
from mindspore.common import dtype as mstype
|
|
26
29
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
27
|
-
from mindspore.common.
|
|
28
|
-
from mindspore.common
|
|
30
|
+
from mindspore.common.initializer import Zero
|
|
31
|
+
from mindspore.common import Tensor, CSRTensor, COOTensor
|
|
32
|
+
from mindspore.common import mutable
|
|
33
|
+
from mindspore import ops
|
|
34
|
+
from mindspore.ops.primitive import _primexpr
|
|
29
35
|
|
|
30
36
|
slice_get_item = SliceGetItem()
|
|
31
37
|
hyper_map = base.HyperMap()
|
|
32
38
|
stack = P.Stack(axis=-1)
|
|
33
39
|
copy_slice = TensorCopySlices()
|
|
34
|
-
dynamic_broadcast_to = DynamicBroadcastTo()
|
|
35
40
|
toptypeof = TopTypeof()
|
|
41
|
+
is_parameter = IsParameter()
|
|
42
|
+
getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
|
|
43
|
+
setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
|
|
36
44
|
|
|
37
45
|
|
|
38
46
|
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
|
|
@@ -43,50 +51,138 @@ def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0,
|
|
|
43
51
|
return strided_slice_(data, begin_strides, end_strides, step_strides)
|
|
44
52
|
|
|
45
53
|
|
|
54
|
+
class ValueTransferType(IntEnum):
|
|
55
|
+
"""Transfer op types of handling tensor getitem/setitem"""
|
|
56
|
+
kUnknown = 0
|
|
57
|
+
kTensorScatterUpdate = 1
|
|
58
|
+
kExpandDims = 2
|
|
59
|
+
kBroadCast = 3
|
|
60
|
+
kCast = 4
|
|
61
|
+
kSelect = 5
|
|
62
|
+
kGather = 6
|
|
63
|
+
kStrideSlice = 7
|
|
64
|
+
kStrideSliceWithMask = 8
|
|
65
|
+
kGatherND = 9
|
|
66
|
+
kScatterNdUpdate = 10
|
|
67
|
+
kReshape = 11
|
|
68
|
+
kScatterND = 12
|
|
69
|
+
kNumberToTensor = 13
|
|
70
|
+
kHandleSequenceValue = 14
|
|
71
|
+
kByPass = 15
|
|
72
|
+
kReSetItemByIndex = 16
|
|
73
|
+
kCopySlice = 17
|
|
74
|
+
kSetItemByBool = 18
|
|
75
|
+
kEmptyTensor = 19
|
|
76
|
+
kSetItemByEllipsis = 20
|
|
77
|
+
kRaiseIndexError = 21
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def data_update(transfer_types, args, data, new_index, value=None):
|
|
81
|
+
"""
|
|
82
|
+
We finally generate a new tensor when handling tensor getitem/setitem
|
|
83
|
+
by transfer data and value with index.
|
|
84
|
+
"""
|
|
85
|
+
for transfer_type, arg in zip(transfer_types, args):
|
|
86
|
+
if transfer_type == ValueTransferType.kUnknown:
|
|
87
|
+
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
88
|
+
if transfer_type <= ValueTransferType.kScatterND:
|
|
89
|
+
data = data_update_by_ops(transfer_type, arg, data, new_index, value)
|
|
90
|
+
if transfer_type == ValueTransferType.kSetItemByBool:
|
|
91
|
+
return tensor_setitem_by_bool(data, new_index, value)
|
|
92
|
+
if transfer_type == ValueTransferType.kCopySlice:
|
|
93
|
+
return copy_slice(data, value.astype(data.dtype), arg[0], arg[1], arg[2])
|
|
94
|
+
if transfer_type == ValueTransferType.kSetItemByEllipsis:
|
|
95
|
+
return tensor_setitem_by_ellipsis(data, new_index, value)
|
|
96
|
+
if transfer_type == ValueTransferType.kReSetItemByIndex:
|
|
97
|
+
data[new_index] = value
|
|
98
|
+
return data
|
|
99
|
+
if transfer_type == ValueTransferType.kEmptyTensor:
|
|
100
|
+
return handle_empty_tensor(arg, data)
|
|
101
|
+
if transfer_type == ValueTransferType.kRaiseIndexError:
|
|
102
|
+
raise IndexError(
|
|
103
|
+
f'index {arg[0]} is out of bounds for dimension with size {arg[1]}')
|
|
104
|
+
return data
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
|
|
108
|
+
"""
|
|
109
|
+
Generate a new tensor when handling tensor getitem/setitem
|
|
110
|
+
by ops.
|
|
111
|
+
"""
|
|
112
|
+
if transfer_type == ValueTransferType.kStrideSliceWithMask:
|
|
113
|
+
stride_info, mask_index = arg[0], arg[1]
|
|
114
|
+
data = strided_slice(data, stride_info[0], stride_info[1], stride_info[2],
|
|
115
|
+
mask_index[0], mask_index[1], 0, 0, mask_index[2])
|
|
116
|
+
elif transfer_type == ValueTransferType.kGatherND:
|
|
117
|
+
if isinstance(new_index, list):
|
|
118
|
+
new_index = handle_multi_dim_index_tensor(new_index, arg)
|
|
119
|
+
data = F.gather_nd(data, Tensor(new_index))
|
|
120
|
+
elif transfer_type == ValueTransferType.kTensorScatterUpdate:
|
|
121
|
+
if isinstance(new_index, list):
|
|
122
|
+
new_index = handle_multi_dim_index_tensor(new_index, arg)
|
|
123
|
+
data = F.tensor_scatter_update(data, new_index, value)
|
|
124
|
+
elif transfer_type == ValueTransferType.kScatterNdUpdate:
|
|
125
|
+
F.scatter_nd_update(data, new_index, value)
|
|
126
|
+
elif transfer_type == ValueTransferType.kSelect:
|
|
127
|
+
data = F.select(Tensor(new_index), value, data)
|
|
128
|
+
elif transfer_type == ValueTransferType.kReshape:
|
|
129
|
+
data = F.reshape(data, arg)
|
|
130
|
+
elif transfer_type == ValueTransferType.kGather:
|
|
131
|
+
data = F.gather(data, new_index, 0)
|
|
132
|
+
elif transfer_type == ValueTransferType.kExpandDims:
|
|
133
|
+
data = F.expand_dims(data, 0)
|
|
134
|
+
elif transfer_type == ValueTransferType.kStrideSlice:
|
|
135
|
+
data = F.strided_slice(data, arg[0], arg[1], arg[2])
|
|
136
|
+
else:
|
|
137
|
+
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
138
|
+
return data
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def value_update(transfer_types, args, data, value):
|
|
142
|
+
"""Transfer value before set value to tensor when handling tensor setitem"""
|
|
143
|
+
for transfer_type, arg in zip(transfer_types, args):
|
|
144
|
+
if transfer_type == ValueTransferType.kByPass:
|
|
145
|
+
continue
|
|
146
|
+
if transfer_type == ValueTransferType.kNumberToTensor:
|
|
147
|
+
value = F.fill(F.dtype(data), (), value)
|
|
148
|
+
elif transfer_type == ValueTransferType.kHandleSequenceValue:
|
|
149
|
+
op_type, index = arg
|
|
150
|
+
if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
|
|
151
|
+
index = Tensor(index)
|
|
152
|
+
value = _generate_updates_from_sequence(
|
|
153
|
+
data, index, value, op_type)
|
|
154
|
+
elif transfer_type == ValueTransferType.kExpandDims:
|
|
155
|
+
value = F.expand_dims(value, arg)
|
|
156
|
+
elif transfer_type == ValueTransferType.kBroadCast:
|
|
157
|
+
value = _broadcast(arg, value.astype(F.dtype(data)))
|
|
158
|
+
elif transfer_type == ValueTransferType.kCast:
|
|
159
|
+
value = F.cast(value, F.dtype(data))
|
|
160
|
+
elif transfer_type == ValueTransferType.kReshape:
|
|
161
|
+
value = F.reshape(value, arg)
|
|
162
|
+
elif transfer_type == ValueTransferType.kScatterND:
|
|
163
|
+
value = F.scatter_nd(arg[0], value, arg[1])
|
|
164
|
+
else:
|
|
165
|
+
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
166
|
+
return value
|
|
167
|
+
|
|
168
|
+
|
|
46
169
|
def _tensor_getitem(self, index):
|
|
47
170
|
"""Handle tensor getitem"""
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
return tensor_index_by_list(self, index)
|
|
52
|
-
if isinstance(index, tuple):
|
|
53
|
-
return tensor_index_by_tuple(self, index)
|
|
54
|
-
if isinstance(index, bool):
|
|
55
|
-
return _tensor_index_by_bool(self, index)
|
|
56
|
-
if isinstance(index, int):
|
|
57
|
-
return _tensor_index_by_integer(self, index)
|
|
58
|
-
if isinstance(index, slice):
|
|
59
|
-
return tensor_index_by_slice(self, index)
|
|
60
|
-
if index is None:
|
|
61
|
-
return F.expand_dims(self, 0)
|
|
62
|
-
if index is ...:
|
|
63
|
-
return self
|
|
64
|
-
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, "
|
|
65
|
-
f"list and tuple ,but got {index} with type {type(index)}.")
|
|
171
|
+
new_index, tensor_update_types, tensor_update_args = getitem_tensor_index_info(
|
|
172
|
+
self, index)
|
|
173
|
+
return data_update(tensor_update_types, tensor_update_args, self, new_index)
|
|
66
174
|
|
|
67
175
|
|
|
68
176
|
def _tensor_setitem(self, index, value):
|
|
69
177
|
"""Handle tensor setitem"""
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
return tensor_setitem_by_tuple(self, index, value)
|
|
79
|
-
if isinstance(index, bool):
|
|
80
|
-
return tensor_setitem_by_bool(self, index, value)
|
|
81
|
-
if isinstance(index, int):
|
|
82
|
-
return tensor_setitem_by_number(self, index, value)
|
|
83
|
-
if isinstance(index, slice):
|
|
84
|
-
return tensor_setitem_by_slice(self, index, value)
|
|
85
|
-
if index in (None, ...):
|
|
86
|
-
return tensor_setitem_by_ellipsis(self, index, value)
|
|
87
|
-
|
|
88
|
-
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), bool, tensor, \
|
|
89
|
-
list and tuple, but got {index} with type{type(index)}")
|
|
178
|
+
setitem_info = setitem_tensor_index_info(self, index, value)
|
|
179
|
+
new_index = setitem_info[0]
|
|
180
|
+
v_transfer_types = setitem_info[1]
|
|
181
|
+
v_transfer_args = setitem_info[2]
|
|
182
|
+
data_update_types = setitem_info[3]
|
|
183
|
+
data_update_args = setitem_info[4]
|
|
184
|
+
value = value_update(v_transfer_types, v_transfer_args, self, value)
|
|
185
|
+
return data_update(data_update_types, data_update_args, self, new_index, value)
|
|
90
186
|
|
|
91
187
|
|
|
92
188
|
tensor_operator_registry.register("__getitem__", _tensor_getitem)
|
|
@@ -119,6 +215,10 @@ def _tensor_mul(self, other):
|
|
|
119
215
|
return F.mul(self, other)
|
|
120
216
|
|
|
121
217
|
|
|
218
|
+
def _tensor_matmul(self, other):
|
|
219
|
+
return F.matmul(self, other)
|
|
220
|
+
|
|
221
|
+
|
|
122
222
|
def _tensor_div(self, other):
|
|
123
223
|
if isinstance(self, (tuple, list)):
|
|
124
224
|
self = sequence_to_tensor(self, F.dtype(other))
|
|
@@ -158,6 +258,7 @@ def _tensor_floordiv(self, other):
|
|
|
158
258
|
tensor_operator_registry.register('__add__', _tensor_add)
|
|
159
259
|
tensor_operator_registry.register('__sub__', _tensor_sub)
|
|
160
260
|
tensor_operator_registry.register('__mul__', _tensor_mul)
|
|
261
|
+
tensor_operator_registry.register('__matmul__', _tensor_matmul)
|
|
161
262
|
tensor_operator_registry.register('__truediv__', _tensor_div)
|
|
162
263
|
tensor_operator_registry.register('__mod__', _tensor_mod)
|
|
163
264
|
tensor_operator_registry.register('__pow__', _tensor_pow)
|
|
@@ -165,6 +266,13 @@ tensor_operator_registry.register('__rpow__', _tensor_rpow)
|
|
|
165
266
|
tensor_operator_registry.register('__floordiv__', _tensor_floordiv)
|
|
166
267
|
|
|
167
268
|
|
|
269
|
+
def _scalar_to_tensor(input_x):
|
|
270
|
+
if ops.isconstant(input_x):
|
|
271
|
+
return P.ScalarToTensor()(input_x, ops.dtype(input_x))
|
|
272
|
+
# use add Tensor([0]) cast scalar to tensor.
|
|
273
|
+
return ops.add(input_x, mutable(Tensor(0)))
|
|
274
|
+
|
|
275
|
+
|
|
168
276
|
def tensor_item(data, *args):
|
|
169
277
|
"""Tensor getitem by index whose dtype is int or tuple with int."""
|
|
170
278
|
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
@@ -239,13 +347,9 @@ def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
|
|
|
239
347
|
|
|
240
348
|
def _broadcast(broadcast_shape, x):
|
|
241
349
|
"""Broadcast tensor to the required shape."""
|
|
242
|
-
if
|
|
350
|
+
if F.shape(x) == broadcast_shape:
|
|
243
351
|
return x
|
|
244
|
-
|
|
245
|
-
if multiples:
|
|
246
|
-
x = F.reshape(x, const_utils.expanded_shape(F.shape(x), len(multiples) - F.rank(x)))
|
|
247
|
-
return F.tile(x, multiples)
|
|
248
|
-
return x
|
|
352
|
+
return F.broadcast_to(x, broadcast_shape)
|
|
249
353
|
|
|
250
354
|
|
|
251
355
|
def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, item):
|
|
@@ -285,6 +389,46 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
|
|
|
285
389
|
return tuple_index_new
|
|
286
390
|
|
|
287
391
|
|
|
392
|
+
def handle_empty_tensor(arg, data):
|
|
393
|
+
"""handle data update with empty tensor"""
|
|
394
|
+
if 0 in arg:
|
|
395
|
+
init_func = Zero()
|
|
396
|
+
init_func.__enable_zero_dim__ = True
|
|
397
|
+
return Tensor(shape=arg, dtype=data.dtype, init=init_func)
|
|
398
|
+
return const_utils.make_tensor([], data.dtype, arg)
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def handle_multi_dim_index_tensor(new_index, arg):
|
|
402
|
+
"""handle data update with multi dim index tensor"""
|
|
403
|
+
slice_cnt = 0
|
|
404
|
+
new_indies_tensor = []
|
|
405
|
+
if len(arg) == 1:
|
|
406
|
+
broadcast_shape = arg[0]
|
|
407
|
+
new_index = hyper_map(F.partial(Tensor), new_index)
|
|
408
|
+
broadcast_tensors = hyper_map(
|
|
409
|
+
F.partial(_broadcast, broadcast_shape), new_index)
|
|
410
|
+
new_broadcast_tensors = ()
|
|
411
|
+
for tensor in broadcast_tensors:
|
|
412
|
+
new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
|
|
413
|
+
new_index = stack(new_broadcast_tensors)
|
|
414
|
+
return new_index
|
|
415
|
+
broadcast_shape, final_shape, index_tensor_new_shape, slice_shapes, tensor_positions, fancy_position = arg
|
|
416
|
+
for i, index in enumerate(new_index):
|
|
417
|
+
if i in tensor_positions:
|
|
418
|
+
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
|
|
419
|
+
Tensor(index))
|
|
420
|
+
new_indies_tensor.append(F.cast(transform_tensor, mstype.int64))
|
|
421
|
+
else:
|
|
422
|
+
shape = const_utils.compute_slice_shape(
|
|
423
|
+
slice_shapes, len(broadcast_shape), slice_cnt, fancy_position)
|
|
424
|
+
array = Tensor(index).reshape(shape)
|
|
425
|
+
slice_index_tensor = _broadcast(final_shape, array)
|
|
426
|
+
new_indies_tensor.append(F.cast(slice_index_tensor, mstype.int64))
|
|
427
|
+
slice_cnt += 1
|
|
428
|
+
new_index = stack(new_indies_tensor)
|
|
429
|
+
return new_index
|
|
430
|
+
|
|
431
|
+
|
|
288
432
|
def _expand_data_dims(data, tuple_index):
|
|
289
433
|
"""expand the data's dim with 'None' and 'Boolean' in tuple_index"""
|
|
290
434
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
@@ -307,12 +451,34 @@ def _expand_data_dims(data, tuple_index):
|
|
|
307
451
|
return data, tuple_index_new
|
|
308
452
|
|
|
309
453
|
|
|
454
|
+
def convert_variable_to_tensor_slice(slice_index):
|
|
455
|
+
"""convert mutable scalar to tensor"""
|
|
456
|
+
start = slice_get_item(slice_index, "start")
|
|
457
|
+
stop = slice_get_item(slice_index, "stop")
|
|
458
|
+
step = slice_get_item(slice_index, "step")
|
|
459
|
+
find_mutable_scalar = False
|
|
460
|
+
if isinstance(start, int) and not F.isconstant(start):
|
|
461
|
+
start = ops.Cast()(start, mstype.int64)
|
|
462
|
+
find_mutable_scalar = True
|
|
463
|
+
if isinstance(stop, int) and not F.isconstant(stop):
|
|
464
|
+
stop = ops.Cast()(stop, mstype.int64)
|
|
465
|
+
find_mutable_scalar = True
|
|
466
|
+
if isinstance(step, int) and not F.isconstant(step):
|
|
467
|
+
step = ops.Cast()(step, mstype.int64)
|
|
468
|
+
find_mutable_scalar = True
|
|
469
|
+
if find_mutable_scalar:
|
|
470
|
+
return F.make_slice(start, stop, step)
|
|
471
|
+
return slice_index
|
|
472
|
+
|
|
473
|
+
|
|
310
474
|
def tensor_index_by_slice(data, slice_index):
|
|
311
475
|
"""Tensor getitem by a slice."""
|
|
312
476
|
min_data_dim, max_data_dim = 1, 8
|
|
313
477
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
314
478
|
data_shape = F.shape(data)
|
|
315
|
-
|
|
479
|
+
slice_index = convert_variable_to_tensor_slice(slice_index)
|
|
480
|
+
|
|
481
|
+
is_dynamic = (F.is_sequence_value_unknown(data_shape)
|
|
316
482
|
or isinstance(slice_get_item(slice_index, "start"), Tensor)
|
|
317
483
|
or isinstance(slice_get_item(slice_index, "stop"), Tensor)
|
|
318
484
|
or isinstance(slice_get_item(slice_index, "step"), Tensor))
|
|
@@ -335,6 +501,12 @@ def get_stride_info_from_slice(data, slice_index):
|
|
|
335
501
|
data_shape = F.dyn_shape(data)
|
|
336
502
|
begin_strides, end_strides, step_strides = [], [], []
|
|
337
503
|
start, stop, step = get_slice_stride(slice_index, data_shape[0])
|
|
504
|
+
if start.ndim > 0:
|
|
505
|
+
start = start.item()
|
|
506
|
+
if stop.ndim > 0:
|
|
507
|
+
stop = stop.item()
|
|
508
|
+
if step.ndim > 0:
|
|
509
|
+
step = step.item()
|
|
338
510
|
begin_strides.append(start)
|
|
339
511
|
end_strides.append(stop)
|
|
340
512
|
step_strides.append(step)
|
|
@@ -364,19 +536,10 @@ def _tensor_index_by_bool(data, bool_value):
|
|
|
364
536
|
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
|
|
365
537
|
|
|
366
538
|
|
|
367
|
-
def check_range(x, dim_size):
|
|
368
|
-
"""Check whether x is within the range of dim_size"""
|
|
369
|
-
tensor_x = const_utils.make_tensor(x)
|
|
370
|
-
if tensor_x >= dim_size or tensor_x < -dim_size:
|
|
371
|
-
return tensor_x
|
|
372
|
-
tensor_x = tensor_x % dim_size
|
|
373
|
-
return tensor_x
|
|
374
|
-
|
|
375
|
-
|
|
376
539
|
def get_stride_info_from_integer(tensor_int):
|
|
377
540
|
"""Convert integer to slice"""
|
|
378
541
|
begin_strides = [tensor_int]
|
|
379
|
-
end_strides = [tensor_int +
|
|
542
|
+
end_strides = [tensor_int + 1]
|
|
380
543
|
step_strides = [const_utils.make_tensor(1)]
|
|
381
544
|
begin_tensor = stack(begin_strides)
|
|
382
545
|
end_tensor = stack(end_strides)
|
|
@@ -386,14 +549,15 @@ def get_stride_info_from_integer(tensor_int):
|
|
|
386
549
|
|
|
387
550
|
def _tensor_index_by_integer(data, int_index):
|
|
388
551
|
"""Tensor getitem by a single integer number"""
|
|
552
|
+
data_shape = F.shape(data)
|
|
553
|
+
if not data_shape:
|
|
554
|
+
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
389
555
|
if data.ndim < 1 or data.ndim > 8:
|
|
390
556
|
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
391
557
|
|
|
392
|
-
data_shape
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
transformed_tensor = check_range(int_index, data_shape[0])
|
|
396
|
-
begin_strides, end_strides, step_strides = get_stride_info_from_integer(transformed_tensor)
|
|
558
|
+
if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
|
|
559
|
+
tensor_index = _scalar_to_tensor(int_index)
|
|
560
|
+
begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
|
|
397
561
|
else:
|
|
398
562
|
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
|
399
563
|
begin_strides, end_strides, step_strides = \
|
|
@@ -401,22 +565,41 @@ def _tensor_index_by_integer(data, int_index):
|
|
|
401
565
|
shrink_axis_mask = 1
|
|
402
566
|
begin_mask = 0
|
|
403
567
|
end_mask = 0
|
|
404
|
-
for i in range(
|
|
568
|
+
for i in range(2, 8):
|
|
405
569
|
begin_mask += 2 ** i
|
|
406
570
|
end_mask += 2 ** i
|
|
407
571
|
return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
408
572
|
|
|
409
573
|
|
|
574
|
+
def _check_dim_shape_valid(data, tensor_index):
|
|
575
|
+
"""check dim and shape of tensor_index for tensor(bool) indexing"""
|
|
576
|
+
if data.ndim < tensor_index.ndim:
|
|
577
|
+
raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
|
|
578
|
+
f"dim of index:{tensor_index.ndim}, dim of data:{data.ndim}")
|
|
579
|
+
if data.shape[:tensor_index.ndim] != tensor_index.shape[:]:
|
|
580
|
+
raise IndexError(f"The shape of index {tensor_index.shape} does not match the shape "
|
|
581
|
+
f"of the indexed data {data.shape}")
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def tensor_index_by_bool_tensor(data, tensor_index):
|
|
585
|
+
"""Tensor getitem by a bool tensor"""
|
|
586
|
+
_check_dim_shape_valid(data, tensor_index)
|
|
587
|
+
tensor_index = tensor_index.nonzero()
|
|
588
|
+
return F.gather_nd(data, tensor_index)
|
|
589
|
+
|
|
590
|
+
|
|
410
591
|
def tensor_index_by_tensor(data, tensor_index):
|
|
411
592
|
"""Tensor getitem by a single tensor"""
|
|
412
593
|
min_data_dim, max_data_dim = 0, 7
|
|
413
594
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
595
|
+
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
|
|
596
|
+
return F.gather(data, tensor_index, 0)
|
|
597
|
+
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
|
|
598
|
+
return tensor_index_by_bool_tensor(data, tensor_index)
|
|
599
|
+
exp_msg = const_utils.gen_exception_msg(
|
|
600
|
+
"The tensor index must be int or bool type, but got {}.", F.dtype(tensor_index))
|
|
601
|
+
const_utils.raise_index_error(exp_msg)
|
|
602
|
+
return data
|
|
420
603
|
|
|
421
604
|
|
|
422
605
|
def tensor_index_by_list(data, list_index):
|
|
@@ -427,10 +610,13 @@ def tensor_index_by_list(data, list_index):
|
|
|
427
610
|
data_shape = F.shape(data)
|
|
428
611
|
indexes_types = hyper_map(toptypeof, list_index)
|
|
429
612
|
if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)):
|
|
430
|
-
if data_shape[0]
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
613
|
+
if not F.isconstant(data_shape[0]):
|
|
614
|
+
if all(isinstance(i, bool) for i in list_index):
|
|
615
|
+
const_utils.raise_unimplemented_error(
|
|
616
|
+
"Not supported to the dynamic shape tensor slice by using list of Boolean type")
|
|
617
|
+
tensor_index = const_utils.sequence_to_index(list_index, None)
|
|
618
|
+
else:
|
|
619
|
+
tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
|
|
434
620
|
if tensor_index is False:
|
|
435
621
|
const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
|
|
436
622
|
return F.gather(data, tensor_index, 0)
|
|
@@ -441,18 +627,28 @@ def tensor_index_by_list(data, list_index):
|
|
|
441
627
|
return tensor_index_by_tuple(data, tuple_index_new)
|
|
442
628
|
|
|
443
629
|
|
|
630
|
+
def convert_tupleslice_to_tensor(tuple_index):
|
|
631
|
+
"""convert mutable scalar in slice to tensor"""
|
|
632
|
+
new_tuple_index = []
|
|
633
|
+
for item in tuple_index:
|
|
634
|
+
if isinstance(item, slice):
|
|
635
|
+
item = convert_variable_to_tensor_slice(item)
|
|
636
|
+
new_tuple_index.append(item)
|
|
637
|
+
return tuple(new_tuple_index)
|
|
638
|
+
|
|
639
|
+
|
|
444
640
|
def tensor_index_by_tuple(data, tuple_index):
|
|
445
641
|
"""Tensor getitem by tuple of various types with None"""
|
|
446
642
|
if not tuple_index:
|
|
447
643
|
return data
|
|
448
644
|
|
|
645
|
+
tuple_index = convert_tupleslice_to_tensor(tuple_index)
|
|
449
646
|
op_name = const_utils.TENSOR_GETITEM
|
|
450
647
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
451
648
|
data, tuple_index = _expand_data_dims(data, tuple_index)
|
|
452
649
|
|
|
453
650
|
min_data_dim, max_data_dim = 1, 8
|
|
454
651
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
455
|
-
|
|
456
652
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
457
653
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
458
654
|
if contain_type == const_utils.ALL_BASIC:
|
|
@@ -460,31 +656,6 @@ def tensor_index_by_tuple(data, tuple_index):
|
|
|
460
656
|
return _tensor_getitem_by_tuple(data, tuple_index, op_name)
|
|
461
657
|
|
|
462
658
|
|
|
463
|
-
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
|
|
464
|
-
"""Tensor getitem by a tuple of tensor."""
|
|
465
|
-
data_shape = F.shape(data)
|
|
466
|
-
tuple_index_len = len(tuple_index)
|
|
467
|
-
|
|
468
|
-
indexes_types = hyper_map(F.dtype, tuple_index)
|
|
469
|
-
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
|
|
470
|
-
tensor_index_shape = hyper_map(F.shape, tuple_index)
|
|
471
|
-
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
|
|
472
|
-
if 0 in broadcast_shape:
|
|
473
|
-
res_shape = broadcast_shape
|
|
474
|
-
if tuple_index_len < len(data_shape):
|
|
475
|
-
res_shape += data_shape[tuple_index_len:]
|
|
476
|
-
res = const_utils.make_tensor([], data.dtype, res_shape)
|
|
477
|
-
return res
|
|
478
|
-
|
|
479
|
-
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
|
|
480
|
-
new_broadcast_tensors = ()
|
|
481
|
-
for tensor in broadcast_tensors:
|
|
482
|
-
new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
|
|
483
|
-
indices = stack(new_broadcast_tensors)
|
|
484
|
-
result = F.gather_nd(data, indices)
|
|
485
|
-
return result
|
|
486
|
-
|
|
487
|
-
|
|
488
659
|
def get_slice_stride(slice_index, dim_size):
|
|
489
660
|
"""Get slice stride info"""
|
|
490
661
|
start = slice_get_item(slice_index, "start")
|
|
@@ -498,13 +669,13 @@ def get_slice_stride(slice_index, dim_size):
|
|
|
498
669
|
if step is None:
|
|
499
670
|
step = const_utils.make_tensor(1)
|
|
500
671
|
|
|
501
|
-
if
|
|
672
|
+
if issubclass_(F.typeof(start), mstype.number):
|
|
502
673
|
start = const_utils.make_tensor(start)
|
|
503
674
|
|
|
504
|
-
if
|
|
675
|
+
if issubclass_(F.typeof(stop), mstype.number):
|
|
505
676
|
stop = const_utils.make_tensor(stop)
|
|
506
677
|
|
|
507
|
-
if
|
|
678
|
+
if issubclass_(F.typeof(step), mstype.number):
|
|
508
679
|
step = const_utils.make_tensor(step)
|
|
509
680
|
|
|
510
681
|
return start, stop, step
|
|
@@ -543,7 +714,7 @@ def _get_stride_info_from_tuple(data, tuple_index):
|
|
|
543
714
|
step_strides.append(step)
|
|
544
715
|
index_count = index_count + 1
|
|
545
716
|
elif isinstance(index, int):
|
|
546
|
-
int_tensor =
|
|
717
|
+
int_tensor = _scalar_to_tensor(index)
|
|
547
718
|
begin_strides.append(int_tensor)
|
|
548
719
|
end_strides.append(int_tensor + const_utils.make_tensor(1))
|
|
549
720
|
step_strides.append(const_utils.make_tensor(1))
|
|
@@ -577,7 +748,7 @@ def _get_stride_info_from_tuple(data, tuple_index):
|
|
|
577
748
|
def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|
578
749
|
"""Tensor getitem by a tuple of slice"""
|
|
579
750
|
data_shape = F.shape(data)
|
|
580
|
-
is_dynamic =
|
|
751
|
+
is_dynamic = F.is_sequence_value_unknown(data_shape)
|
|
581
752
|
for item in tuple_index:
|
|
582
753
|
if isinstance(item, slice):
|
|
583
754
|
is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
|
|
@@ -599,6 +770,39 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|
|
599
770
|
return strided_slice(data, begin_v, end_v, step_v, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
600
771
|
|
|
601
772
|
|
|
773
|
+
@_primexpr
|
|
774
|
+
def _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
775
|
+
tensor_positions_new):
|
|
776
|
+
""" parse index of bool tensor type """
|
|
777
|
+
indices = index.nonzero()
|
|
778
|
+
if indices.shape[0] == 0:
|
|
779
|
+
return None, tensor_indexes, tensor_positions_new
|
|
780
|
+
indices = F.cast(indices, mstype.int64)
|
|
781
|
+
indices = indices.T
|
|
782
|
+
for sub_index in indices:
|
|
783
|
+
tensor_positions_new.append(len(tuple_index_new))
|
|
784
|
+
tuple_index_new += (sub_index,)
|
|
785
|
+
tensor_indexes.append(sub_index)
|
|
786
|
+
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
def _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new, tensor_indexes, tensor_positions_new):
|
|
790
|
+
""" parse index of tensor type """
|
|
791
|
+
if F.dtype(index) in mstype.int_type:
|
|
792
|
+
tensor_index = F.cast(index, mstype.int64)
|
|
793
|
+
tensor_positions_new.append(len(tuple_index_new))
|
|
794
|
+
tuple_index_new += (tensor_index,)
|
|
795
|
+
tensor_indexes.append(tensor_index)
|
|
796
|
+
elif F.dtype(index) == mstype.bool_:
|
|
797
|
+
return _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
798
|
+
tensor_positions_new)
|
|
799
|
+
else:
|
|
800
|
+
exp_msg = const_utils.gen_exception_msg(
|
|
801
|
+
"The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
|
|
802
|
+
const_utils.raise_index_error(exp_msg)
|
|
803
|
+
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
804
|
+
|
|
805
|
+
|
|
602
806
|
def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|
603
807
|
"""Tensor getitem by a tuple of mixed tensor."""
|
|
604
808
|
slice_is_tensor = False
|
|
@@ -609,51 +813,49 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|
|
609
813
|
or isinstance(slice_get_item(item, "step"), Tensor)
|
|
610
814
|
if slice_is_tensor:
|
|
611
815
|
const_utils.raise_index_error("Not supported when slice has tensor")
|
|
612
|
-
|
|
613
|
-
tensor_indexes, slice_indexes = [], []
|
|
816
|
+
|
|
614
817
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
615
818
|
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
|
|
616
819
|
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
|
617
|
-
tuple_index_new, slice_shapes = (), ()
|
|
618
820
|
data_shape = F.shape(data)
|
|
821
|
+
tensor_indexes, slice_indexes = [], []
|
|
822
|
+
tuple_index_new, slice_shapes = (), ()
|
|
823
|
+
slice_positions_new, tensor_positions_new = [], []
|
|
619
824
|
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
|
|
620
825
|
if i in int_positions:
|
|
621
826
|
int_index = const_utils.check_range(index, dim_size)
|
|
622
827
|
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
|
|
623
|
-
if
|
|
624
|
-
|
|
625
|
-
tensor_index = check_range(index, dyn_shape[i])
|
|
828
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
829
|
+
tensor_index = _scalar_to_tensor(int_index)
|
|
626
830
|
tensor_index = F.cast(tensor_index, mstype.int64)
|
|
831
|
+
tensor_positions_new.append(len(tuple_index_new))
|
|
627
832
|
tuple_index_new += (tensor_index,)
|
|
628
833
|
tensor_indexes.append(tensor_index)
|
|
629
|
-
tensor_positions += (i,)
|
|
630
834
|
elif i in sequence_positions:
|
|
631
835
|
tensor_index = const_utils.sequence_to_index(index, dim_size)
|
|
632
836
|
if tensor_index is False:
|
|
633
837
|
const_utils.raise_index_error("The sequence element(tuple/list) in tuple index can't be empty.")
|
|
838
|
+
tensor_positions_new.append(len(tuple_index_new))
|
|
634
839
|
tuple_index_new += (tensor_index,)
|
|
635
840
|
tensor_indexes.append(tensor_index)
|
|
636
|
-
tensor_positions += (i,)
|
|
637
841
|
elif i in tensor_positions:
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
tensor_index = F.cast(index, mstype.int64)
|
|
644
|
-
tuple_index_new += (tensor_index,)
|
|
645
|
-
tensor_indexes.append(tensor_index)
|
|
842
|
+
tuple_index_new, tensor_indexes, tensor_positions_new = \
|
|
843
|
+
_tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new,
|
|
844
|
+
tensor_indexes, tensor_positions_new)
|
|
845
|
+
if tuple_index_new is None:
|
|
846
|
+
return Tensor([])
|
|
646
847
|
elif i in slice_positions:
|
|
647
848
|
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
|
|
648
849
|
slice_shapes += (len(slice_ele_list_index),)
|
|
850
|
+
slice_positions_new.append(len(tuple_index_new))
|
|
649
851
|
tuple_index_new += (slice_ele_list_index,)
|
|
650
852
|
slice_indexes.append(slice_ele_list_index)
|
|
651
|
-
|
|
652
853
|
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
|
653
854
|
broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
|
|
654
|
-
const_utils.generate_index_info_from_tuple_of_mixed_tensors(
|
|
855
|
+
const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions_new, tensor_indexes_shapes,
|
|
655
856
|
slice_shapes, op_name)
|
|
656
857
|
|
|
858
|
+
tuple_index_len = len(tuple_index)
|
|
657
859
|
if 0 in final_shape + data_shape:
|
|
658
860
|
if tuple_index_len < len(data_shape):
|
|
659
861
|
final_shape = final_shape + data_shape[tuple_index_len:]
|
|
@@ -662,11 +864,11 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|
|
662
864
|
final_index_tensors = []
|
|
663
865
|
slice_cnt = 0
|
|
664
866
|
for i, index in enumerate(tuple_index_new):
|
|
665
|
-
if i in
|
|
867
|
+
if i in tensor_positions_new:
|
|
666
868
|
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
|
|
667
869
|
index)
|
|
668
870
|
final_index_tensors.append(transform_tensor)
|
|
669
|
-
elif i in
|
|
871
|
+
elif i in slice_positions_new:
|
|
670
872
|
slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
|
|
671
873
|
slice_shapes, fancy_position)
|
|
672
874
|
final_index_tensors.append(slice_index_tensor)
|
|
@@ -701,7 +903,6 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
|
|
|
701
903
|
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
|
|
702
904
|
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
|
703
905
|
tuple_index_new, slice_shapes = (), ()
|
|
704
|
-
|
|
705
906
|
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
|
|
706
907
|
if i in int_positions:
|
|
707
908
|
int_index = const_utils.check_range(index, dim_size)
|
|
@@ -718,7 +919,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
|
|
|
718
919
|
invalid = const_utils.check_type_invalid(F.dtype(index), mstype.int_type)
|
|
719
920
|
if invalid:
|
|
720
921
|
exp_msg = const_utils.gen_exception_msg(
|
|
721
|
-
"The tensor element in tuple index must be int type, but got {}.", F.dtype(index))
|
|
922
|
+
"The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
|
|
722
923
|
const_utils.raise_index_error(exp_msg)
|
|
723
924
|
tensor_index = F.cast(index, mstype.int64)
|
|
724
925
|
tuple_index_new += (tensor_index,)
|
|
@@ -783,11 +984,11 @@ def _generate_updates_from_sequence(data, index, value, op_type):
|
|
|
783
984
|
def _generate_updates_from_tensor(data, index, value, op_type):
|
|
784
985
|
"""Generate an updates tensor from a tensor."""
|
|
785
986
|
value = value.astype(data.dtype)
|
|
786
|
-
if
|
|
987
|
+
if F.is_sequence_value_unknown(F.shape(data)):
|
|
787
988
|
data_shape = F.dyn_shape(data)
|
|
788
989
|
index_shape = F.dyn_shape(index)
|
|
789
990
|
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
|
790
|
-
updates =
|
|
991
|
+
updates = ops.broadcast_to(value, updates_shape)
|
|
791
992
|
return updates
|
|
792
993
|
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type, False)
|
|
793
994
|
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
|
|
@@ -807,6 +1008,7 @@ def tensor_setitem_by_tensor(self, index, value):
|
|
|
807
1008
|
|
|
808
1009
|
|
|
809
1010
|
def tensor_setitem_by_tuple(self, index, value):
|
|
1011
|
+
index = convert_tupleslice_to_tensor(index)
|
|
810
1012
|
if isinstance(value, (int, float, bool)):
|
|
811
1013
|
index = format_tuple_indices(index)
|
|
812
1014
|
return tensor_setitem_by_tuple_with_number(self, index, value)
|
|
@@ -824,6 +1026,7 @@ def tensor_setitem_by_number(self, index, value):
|
|
|
824
1026
|
|
|
825
1027
|
|
|
826
1028
|
def tensor_setitem_by_slice(self, index, value):
|
|
1029
|
+
index = convert_variable_to_tensor_slice(index)
|
|
827
1030
|
if isinstance(value, (int, float, bool)):
|
|
828
1031
|
return tensor_setitem_by_slice_with_number(self, index, value)
|
|
829
1032
|
if isinstance(value, Tensor):
|
|
@@ -844,28 +1047,29 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
|
|
844
1047
|
if F.rank(index) == 0:
|
|
845
1048
|
index = F.expand_dims(index, -1)
|
|
846
1049
|
updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
|
847
|
-
|
|
1050
|
+
data_shape = F.shape(data)
|
|
1051
|
+
first_val = data_shape[0]
|
|
1052
|
+
if not F.isconstant(first_val):
|
|
1053
|
+
first_val = -1
|
|
1054
|
+
index = F.select(index < 0, index + first_val, index)
|
|
848
1055
|
index = F.expand_dims(index, -1)
|
|
849
1056
|
if F.rank(index) < 2:
|
|
850
1057
|
index = F.expand_dims(index, 0)
|
|
851
1058
|
updates = F.expand_dims(updates, 0)
|
|
1059
|
+
if is_parameter(data):
|
|
1060
|
+
F.scatter_nd_update(data, index, updates)
|
|
1061
|
+
return data
|
|
852
1062
|
return F.tensor_scatter_update(data, index, updates)
|
|
853
1063
|
|
|
854
1064
|
|
|
855
1065
|
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
|
856
1066
|
"""Set a tensor item by a bool tensor with a tensor."""
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
|
864
|
-
dtype = F.dtype(data)
|
|
865
|
-
u_cast = F.cast(value, dtype)
|
|
866
|
-
one_data = F.ones_like(data)
|
|
867
|
-
u = F.tensor_mul(one_data, u_cast)
|
|
868
|
-
result = F.select(index, u, data)
|
|
1067
|
+
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
|
1068
|
+
index = F.broadcast_to(index, data.shape)
|
|
1069
|
+
value = F.cast(value, F.dtype(data))
|
|
1070
|
+
value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
|
|
1071
|
+
value = F.broadcast_to(value, data.shape)
|
|
1072
|
+
result = F.select(index, value, data)
|
|
869
1073
|
return result
|
|
870
1074
|
|
|
871
1075
|
|
|
@@ -876,7 +1080,7 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|
|
876
1080
|
if tensor_dtype == const_utils.INT_:
|
|
877
1081
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
|
878
1082
|
|
|
879
|
-
if
|
|
1083
|
+
if F.is_sequence_value_unknown(F.shape(data)):
|
|
880
1084
|
const_utils.raise_unimplemented_error(
|
|
881
1085
|
"Not supported to the dynamic shape tensor slice by using tensor of Boolean type")
|
|
882
1086
|
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
|
@@ -890,11 +1094,13 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
|
|
|
890
1094
|
def tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
891
1095
|
"""Assigns the tensor by tensor with tuple value."""
|
|
892
1096
|
index_dtype = F.dtype(index)
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
1097
|
+
if index_dtype in (mstype.int32, mstype.int64):
|
|
1098
|
+
return _tensor_setitem_by_tensor_with_sequence(data, index, value)
|
|
1099
|
+
if index_dtype == mstype.bool_:
|
|
1100
|
+
return _tensor_setitem_by_bool_tensor_with_sequence(data, index, value)
|
|
1101
|
+
exp_msg = const_utils.gen_exception_msg("The tensor index must be int or bool type, but got {}.", index_dtype)
|
|
1102
|
+
const_utils.raise_index_error(exp_msg)
|
|
1103
|
+
return None
|
|
898
1104
|
|
|
899
1105
|
|
|
900
1106
|
def _tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
@@ -904,6 +1110,12 @@ def _tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
|
904
1110
|
return F.tensor_scatter_update(data, index, updates)
|
|
905
1111
|
|
|
906
1112
|
|
|
1113
|
+
def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
|
|
1114
|
+
"""Set a tensor item by a bool tensor with a tuple."""
|
|
1115
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1116
|
+
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value)
|
|
1117
|
+
|
|
1118
|
+
|
|
907
1119
|
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
|
908
1120
|
"""Givens a scalar assign to tensor by slice"""
|
|
909
1121
|
value = F.fill(F.dtype(data), (), value)
|
|
@@ -929,7 +1141,7 @@ def tensor_copy_slice_from_slice(data, input_slice, value):
|
|
|
929
1141
|
if dim0_size >= data_shape[0]:
|
|
930
1142
|
dim0_size = data_shape[0:1]
|
|
931
1143
|
value_shape = P.Concat(-1)((dim0_size, data_shape[1:]))
|
|
932
|
-
value =
|
|
1144
|
+
value = ops.broadcast_to(value, value_shape)
|
|
933
1145
|
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
934
1146
|
|
|
935
1147
|
|
|
@@ -941,7 +1153,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
|
|
941
1153
|
data_shape = F.shape(data)
|
|
942
1154
|
step = const_utils.get_step_from_slice(input_slice)
|
|
943
1155
|
if step == 1 and not const_utils.is_ascend():
|
|
944
|
-
if
|
|
1156
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
945
1157
|
return tensor_copy_slice_from_slice(data, input_slice, value)
|
|
946
1158
|
start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
|
|
947
1159
|
dim0_size = stop - start
|
|
@@ -950,7 +1162,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
|
|
950
1162
|
value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
|
|
951
1163
|
value = _broadcast(value_shape, value)
|
|
952
1164
|
return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
|
|
953
|
-
if
|
|
1165
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
954
1166
|
const_utils.raise_unimplemented_error(
|
|
955
1167
|
"Not supported to take the subscript of dynamic shape tensor slice setitem")
|
|
956
1168
|
indices = const_utils.slice2indices(input_slice, data_shape)
|
|
@@ -974,7 +1186,7 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
|
974
1186
|
dim1_start, dim1_stop, _ = get_slice_stride(tuple_index[1], data_shape[1])
|
|
975
1187
|
if dim1_stop - dim1_start <= 0:
|
|
976
1188
|
return data
|
|
977
|
-
dim0_start =
|
|
1189
|
+
dim0_start = _scalar_to_tensor(tuple_index[0])
|
|
978
1190
|
dim0_stop = dim0_start + const_utils.make_tensor(1)
|
|
979
1191
|
start = (dim0_start, dim1_start)
|
|
980
1192
|
stop = (dim0_stop, dim1_stop)
|
|
@@ -986,7 +1198,7 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
|
986
1198
|
if dim1_size > data_shape[1]:
|
|
987
1199
|
dim1_size = data_shape[1:2]
|
|
988
1200
|
value_shape = P.Concat(-1)((dim1_size, data_shape[2:]))
|
|
989
|
-
value =
|
|
1201
|
+
value = ops.broadcast_to(value, value_shape)
|
|
990
1202
|
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
991
1203
|
|
|
992
1204
|
|
|
@@ -996,7 +1208,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
996
1208
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
997
1209
|
|
|
998
1210
|
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
999
|
-
if
|
|
1211
|
+
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1000
1212
|
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1001
1213
|
dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
|
|
1002
1214
|
if dim1_stop - dim1_start <= 0:
|
|
@@ -1016,7 +1228,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1016
1228
|
if len(tuple_index) == 1:
|
|
1017
1229
|
data[tuple_index[0]] = value
|
|
1018
1230
|
return data
|
|
1019
|
-
|
|
1020
1231
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
1021
1232
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
1022
1233
|
|
|
@@ -1050,14 +1261,20 @@ def tensor_setitem_by_number_with_sequence(data, index, value):
|
|
|
1050
1261
|
def tensor_setitem_by_number_with_tensor(data, index, value):
|
|
1051
1262
|
"""Assigns the tensor by number with tensor value."""
|
|
1052
1263
|
data_shape = F.shape(data)
|
|
1053
|
-
if
|
|
1054
|
-
index =
|
|
1264
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
1265
|
+
index = _scalar_to_tensor(index)
|
|
1055
1266
|
index = F.expand_dims(index, -1)
|
|
1056
1267
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value)
|
|
1057
1268
|
|
|
1269
|
+
dim_size = data_shape[0]
|
|
1270
|
+
if index < -dim_size or index >= dim_size:
|
|
1271
|
+
raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
|
|
1058
1272
|
index = const_utils.int_to_index(index, data_shape)
|
|
1059
1273
|
value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
|
|
1060
1274
|
value = _broadcast(value_shape, value.astype(F.dtype(data)))
|
|
1275
|
+
if is_parameter(data):
|
|
1276
|
+
F.scatter_nd_update(data, index, value)
|
|
1277
|
+
return data
|
|
1061
1278
|
return F.tensor_scatter_update(data, index, value)
|
|
1062
1279
|
|
|
1063
1280
|
|
|
@@ -1065,7 +1282,7 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
|
|
|
1065
1282
|
"""Assigns the tensor by ellipsis with number value."""
|
|
1066
1283
|
data_shape = F.shape(data)
|
|
1067
1284
|
data_dtype = F.dtype(data)
|
|
1068
|
-
if
|
|
1285
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
1069
1286
|
value = F.fill(F.dtype(data), (), value)
|
|
1070
1287
|
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1071
1288
|
return F.fill(data_dtype, data_shape, value)
|
|
@@ -1077,9 +1294,9 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
|
|
|
1077
1294
|
data_dtype = F.dtype(data)
|
|
1078
1295
|
value = value.astype(data_dtype)
|
|
1079
1296
|
|
|
1080
|
-
if
|
|
1297
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
1081
1298
|
data_shape = F.dyn_shape(data)
|
|
1082
|
-
data =
|
|
1299
|
+
data = ops.broadcast_to(value, data_shape)
|
|
1083
1300
|
return data
|
|
1084
1301
|
value_shape = F.shape(value)
|
|
1085
1302
|
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
|
@@ -1107,9 +1324,9 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1107
1324
|
elif isinstance(value, float):
|
|
1108
1325
|
value = const_utils.make_tensor(value, mstype.float32)
|
|
1109
1326
|
|
|
1110
|
-
if
|
|
1327
|
+
if F.is_sequence_value_unknown(data_shape) and index:
|
|
1111
1328
|
data_shape = F.dyn_shape(data)
|
|
1112
|
-
data =
|
|
1329
|
+
data = ops.broadcast_to(value, data_shape)
|
|
1113
1330
|
return data
|
|
1114
1331
|
value_shape = F.shape(value)
|
|
1115
1332
|
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
|
@@ -1135,6 +1352,8 @@ def format_list_indices(list_indices, length):
|
|
|
1135
1352
|
# If eyery element in list is bool, it's treated as 1-D bool tensor.
|
|
1136
1353
|
# If every element in list is int(not all bool), it's treated as int tensor.
|
|
1137
1354
|
if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
|
|
1355
|
+
if not F.isconstant(length):
|
|
1356
|
+
return const_utils.sequence_to_index(list_indices, None)
|
|
1138
1357
|
return const_utils.sequence_to_index(list_indices, length)
|
|
1139
1358
|
# If list contains other types(.../list/tuple/None), it's treated as a tuple
|
|
1140
1359
|
return const_utils.deep_tuple(list_indices)
|
|
@@ -1154,11 +1373,34 @@ def format_tuple_indices(tuple_indices):
|
|
|
1154
1373
|
return res
|
|
1155
1374
|
|
|
1156
1375
|
|
|
1376
|
+
@_primexpr
|
|
1377
|
+
def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
|
|
1378
|
+
""" Parse bool tensor index """
|
|
1379
|
+
index_out = index_out.nonzero()
|
|
1380
|
+
if index_out.shape[0] == 0:
|
|
1381
|
+
return None, shapes, cur_dim
|
|
1382
|
+
for i in range(index_out.shape[1]):
|
|
1383
|
+
out = index_out[:, i]
|
|
1384
|
+
indices_out += (out,)
|
|
1385
|
+
shapes.append(F.shape(out))
|
|
1386
|
+
cur_dim += 1
|
|
1387
|
+
return indices_out, shapes, cur_dim
|
|
1388
|
+
|
|
1389
|
+
|
|
1390
|
+
def remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim):
|
|
1391
|
+
""" Parse tensor index """
|
|
1392
|
+
if index_out.dtype == mstype.bool_:
|
|
1393
|
+
return remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim)
|
|
1394
|
+
indices_out += (index_out,)
|
|
1395
|
+
shapes.append(F.shape(index_out))
|
|
1396
|
+
cur_dim += 1
|
|
1397
|
+
return indices_out, shapes, cur_dim
|
|
1398
|
+
|
|
1399
|
+
|
|
1157
1400
|
def remove_expanded_dims(tuple_index, data_shape, value):
|
|
1158
1401
|
"""Removes expanded dimensions in tuple_index and value."""
|
|
1159
|
-
op_name = const_utils.TENSOR_SETITEM
|
|
1160
1402
|
not_expanded_dim = ()
|
|
1161
|
-
shapes =
|
|
1403
|
+
shapes = []
|
|
1162
1404
|
has_true = False
|
|
1163
1405
|
has_false = False
|
|
1164
1406
|
has_sequence = False
|
|
@@ -1185,17 +1427,18 @@ def remove_expanded_dims(tuple_index, data_shape, value):
|
|
|
1185
1427
|
idx_advanced = 0
|
|
1186
1428
|
idx_tensor = i
|
|
1187
1429
|
if isinstance(index_out, Tensor):
|
|
1188
|
-
|
|
1430
|
+
indices_out, shapes, cur_dim = \
|
|
1431
|
+
remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim)
|
|
1432
|
+
if indices_out is None:
|
|
1433
|
+
return False, value, 0
|
|
1434
|
+
if index_out.dtype != mstype.bool_ and F.rank(index_out) > 0:
|
|
1189
1435
|
has_sequence = True
|
|
1190
|
-
indices_out += (index_out,)
|
|
1191
|
-
shapes += (F.shape(index_out),)
|
|
1192
|
-
cur_dim += 1
|
|
1193
1436
|
has_true = has_true or index_out is True
|
|
1194
1437
|
has_false = has_false or index_out is False
|
|
1195
1438
|
else:
|
|
1196
1439
|
const_utils.raise_index_error('invalid index type')
|
|
1197
1440
|
|
|
1198
|
-
broadcast_shape = const_utils.generate_broadcast_shape(shapes,
|
|
1441
|
+
broadcast_shape = const_utils.generate_broadcast_shape(shapes, const_utils.TENSOR_SETITEM)
|
|
1199
1442
|
if has_false:
|
|
1200
1443
|
if F.shape_mul(broadcast_shape) != 1:
|
|
1201
1444
|
const_utils.raise_index_error('unable to broadcast indices')
|
|
@@ -1222,11 +1465,21 @@ def format_index(idx, data_shape, cur_dim):
|
|
|
1222
1465
|
elif isinstance(idx, int) and not isinstance(idx, bool):
|
|
1223
1466
|
idx = const_utils.make_tensor(idx, mstype.int64, None, data_shape[cur_dim])
|
|
1224
1467
|
elif isinstance(idx, Tensor):
|
|
1225
|
-
|
|
1226
|
-
|
|
1468
|
+
tensor_dtype = const_utils.get_index_tensor_dtype(idx.dtype)
|
|
1469
|
+
if tensor_dtype == const_utils.INT_:
|
|
1470
|
+
idx = F.select(idx < 0, idx + data_shape[cur_dim], idx)
|
|
1471
|
+
elif tensor_dtype == const_utils.BOOL_:
|
|
1472
|
+
# index with tensor(bool) type is processed in remove_expanded_dims()
|
|
1473
|
+
pass
|
|
1227
1474
|
return idx
|
|
1228
1475
|
|
|
1229
1476
|
|
|
1477
|
+
@_primexpr
|
|
1478
|
+
def _check_shape_mul(shape):
|
|
1479
|
+
if F.shape_mul(shape) == 0:
|
|
1480
|
+
raise ValueError('zero-size tensors are not supported.')
|
|
1481
|
+
|
|
1482
|
+
|
|
1230
1483
|
def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, where=True, dtype=None):
|
|
1231
1484
|
"""
|
|
1232
1485
|
Applies comparison based on cmp_fn and reduction based on reduce_fn.
|
|
@@ -1243,8 +1496,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
|
|
|
1243
1496
|
not isinstance(initial, (int, float, bool, Tensor))):
|
|
1244
1497
|
const_utils.raise_type_error('initial must be scalar')
|
|
1245
1498
|
|
|
1246
|
-
|
|
1247
|
-
const_utils.raise_value_error('zero-size tensors are not supported.')
|
|
1499
|
+
_check_shape_mul(shape)
|
|
1248
1500
|
|
|
1249
1501
|
if initial is not None:
|
|
1250
1502
|
if isinstance(initial, Tensor):
|