mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,25 +15,32 @@
|
|
|
15
15
|
|
|
16
16
|
"""constexpr util"""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
|
+
from enum import IntEnum
|
|
19
|
+
|
|
18
20
|
|
|
19
21
|
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
20
22
|
from mindspore.ops import functional as F
|
|
21
23
|
from mindspore.ops import operations as P
|
|
22
24
|
from mindspore.ops.composite import base
|
|
23
25
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
24
|
-
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem,
|
|
25
|
-
TopTypeof, issubclass_
|
|
26
|
+
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
|
|
27
|
+
TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo
|
|
26
28
|
from mindspore.common import dtype as mstype
|
|
27
29
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
30
|
+
from mindspore.common.initializer import Zero
|
|
28
31
|
from mindspore.common import Tensor, CSRTensor, COOTensor
|
|
29
|
-
from mindspore.common
|
|
32
|
+
from mindspore.common import mutable
|
|
33
|
+
from mindspore import ops
|
|
34
|
+
from mindspore.ops.primitive import _primexpr
|
|
30
35
|
|
|
31
36
|
slice_get_item = SliceGetItem()
|
|
32
37
|
hyper_map = base.HyperMap()
|
|
33
38
|
stack = P.Stack(axis=-1)
|
|
34
39
|
copy_slice = TensorCopySlices()
|
|
35
|
-
dynamic_broadcast_to = DynamicBroadcastTo()
|
|
36
40
|
toptypeof = TopTypeof()
|
|
41
|
+
is_parameter = IsParameter()
|
|
42
|
+
getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
|
|
43
|
+
setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
|
|
37
44
|
|
|
38
45
|
|
|
39
46
|
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
|
|
@@ -44,50 +51,138 @@ def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0,
|
|
|
44
51
|
return strided_slice_(data, begin_strides, end_strides, step_strides)
|
|
45
52
|
|
|
46
53
|
|
|
54
|
+
class ValueTransferType(IntEnum):
|
|
55
|
+
"""Transfer op types of handling tensor getitem/setitem"""
|
|
56
|
+
kUnknown = 0
|
|
57
|
+
kTensorScatterUpdate = 1
|
|
58
|
+
kExpandDims = 2
|
|
59
|
+
kBroadCast = 3
|
|
60
|
+
kCast = 4
|
|
61
|
+
kSelect = 5
|
|
62
|
+
kGather = 6
|
|
63
|
+
kStrideSlice = 7
|
|
64
|
+
kStrideSliceWithMask = 8
|
|
65
|
+
kGatherND = 9
|
|
66
|
+
kScatterNdUpdate = 10
|
|
67
|
+
kReshape = 11
|
|
68
|
+
kScatterND = 12
|
|
69
|
+
kNumberToTensor = 13
|
|
70
|
+
kHandleSequenceValue = 14
|
|
71
|
+
kByPass = 15
|
|
72
|
+
kReSetItemByIndex = 16
|
|
73
|
+
kCopySlice = 17
|
|
74
|
+
kSetItemByBool = 18
|
|
75
|
+
kEmptyTensor = 19
|
|
76
|
+
kSetItemByEllipsis = 20
|
|
77
|
+
kRaiseIndexError = 21
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def data_update(transfer_types, args, data, new_index, value=None):
|
|
81
|
+
"""
|
|
82
|
+
We finally generate a new tensor when handling tensor getitem/setitem
|
|
83
|
+
by transfer data and value with index.
|
|
84
|
+
"""
|
|
85
|
+
for transfer_type, arg in zip(transfer_types, args):
|
|
86
|
+
if transfer_type == ValueTransferType.kUnknown:
|
|
87
|
+
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
88
|
+
if transfer_type <= ValueTransferType.kScatterND:
|
|
89
|
+
data = data_update_by_ops(transfer_type, arg, data, new_index, value)
|
|
90
|
+
if transfer_type == ValueTransferType.kSetItemByBool:
|
|
91
|
+
return tensor_setitem_by_bool(data, new_index, value)
|
|
92
|
+
if transfer_type == ValueTransferType.kCopySlice:
|
|
93
|
+
return copy_slice(data, value.astype(data.dtype), arg[0], arg[1], arg[2])
|
|
94
|
+
if transfer_type == ValueTransferType.kSetItemByEllipsis:
|
|
95
|
+
return tensor_setitem_by_ellipsis(data, new_index, value)
|
|
96
|
+
if transfer_type == ValueTransferType.kReSetItemByIndex:
|
|
97
|
+
data[new_index] = value
|
|
98
|
+
return data
|
|
99
|
+
if transfer_type == ValueTransferType.kEmptyTensor:
|
|
100
|
+
return handle_empty_tensor(arg, data)
|
|
101
|
+
if transfer_type == ValueTransferType.kRaiseIndexError:
|
|
102
|
+
raise IndexError(
|
|
103
|
+
f'index {arg[0]} is out of bounds for dimension with size {arg[1]}')
|
|
104
|
+
return data
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
|
|
108
|
+
"""
|
|
109
|
+
Generate a new tensor when handling tensor getitem/setitem
|
|
110
|
+
by ops.
|
|
111
|
+
"""
|
|
112
|
+
if transfer_type == ValueTransferType.kStrideSliceWithMask:
|
|
113
|
+
stride_info, mask_index = arg[0], arg[1]
|
|
114
|
+
data = strided_slice(data, stride_info[0], stride_info[1], stride_info[2],
|
|
115
|
+
mask_index[0], mask_index[1], 0, 0, mask_index[2])
|
|
116
|
+
elif transfer_type == ValueTransferType.kGatherND:
|
|
117
|
+
if isinstance(new_index, list):
|
|
118
|
+
new_index = handle_multi_dim_index_tensor(new_index, arg)
|
|
119
|
+
data = F.gather_nd(data, Tensor(new_index))
|
|
120
|
+
elif transfer_type == ValueTransferType.kTensorScatterUpdate:
|
|
121
|
+
if isinstance(new_index, list):
|
|
122
|
+
new_index = handle_multi_dim_index_tensor(new_index, arg)
|
|
123
|
+
data = F.tensor_scatter_update(data, new_index, value)
|
|
124
|
+
elif transfer_type == ValueTransferType.kScatterNdUpdate:
|
|
125
|
+
F.scatter_nd_update(data, new_index, value)
|
|
126
|
+
elif transfer_type == ValueTransferType.kSelect:
|
|
127
|
+
data = F.select(Tensor(new_index), value, data)
|
|
128
|
+
elif transfer_type == ValueTransferType.kReshape:
|
|
129
|
+
data = F.reshape(data, arg)
|
|
130
|
+
elif transfer_type == ValueTransferType.kGather:
|
|
131
|
+
data = F.gather(data, new_index, 0)
|
|
132
|
+
elif transfer_type == ValueTransferType.kExpandDims:
|
|
133
|
+
data = F.expand_dims(data, 0)
|
|
134
|
+
elif transfer_type == ValueTransferType.kStrideSlice:
|
|
135
|
+
data = F.strided_slice(data, arg[0], arg[1], arg[2])
|
|
136
|
+
else:
|
|
137
|
+
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
138
|
+
return data
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def value_update(transfer_types, args, data, value):
|
|
142
|
+
"""Transfer value before set value to tensor when handling tensor setitem"""
|
|
143
|
+
for transfer_type, arg in zip(transfer_types, args):
|
|
144
|
+
if transfer_type == ValueTransferType.kByPass:
|
|
145
|
+
continue
|
|
146
|
+
if transfer_type == ValueTransferType.kNumberToTensor:
|
|
147
|
+
value = F.fill(F.dtype(data), (), value)
|
|
148
|
+
elif transfer_type == ValueTransferType.kHandleSequenceValue:
|
|
149
|
+
op_type, index = arg
|
|
150
|
+
if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
|
|
151
|
+
index = Tensor(index)
|
|
152
|
+
value = _generate_updates_from_sequence(
|
|
153
|
+
data, index, value, op_type)
|
|
154
|
+
elif transfer_type == ValueTransferType.kExpandDims:
|
|
155
|
+
value = F.expand_dims(value, arg)
|
|
156
|
+
elif transfer_type == ValueTransferType.kBroadCast:
|
|
157
|
+
value = _broadcast(arg, value.astype(F.dtype(data)))
|
|
158
|
+
elif transfer_type == ValueTransferType.kCast:
|
|
159
|
+
value = F.cast(value, F.dtype(data))
|
|
160
|
+
elif transfer_type == ValueTransferType.kReshape:
|
|
161
|
+
value = F.reshape(value, arg)
|
|
162
|
+
elif transfer_type == ValueTransferType.kScatterND:
|
|
163
|
+
value = F.scatter_nd(arg[0], value, arg[1])
|
|
164
|
+
else:
|
|
165
|
+
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
166
|
+
return value
|
|
167
|
+
|
|
168
|
+
|
|
47
169
|
def _tensor_getitem(self, index):
|
|
48
170
|
"""Handle tensor getitem"""
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
return tensor_index_by_list(self, index)
|
|
53
|
-
if isinstance(index, tuple):
|
|
54
|
-
return tensor_index_by_tuple(self, index)
|
|
55
|
-
if isinstance(index, bool):
|
|
56
|
-
return _tensor_index_by_bool(self, index)
|
|
57
|
-
if isinstance(index, int):
|
|
58
|
-
return _tensor_index_by_integer(self, index)
|
|
59
|
-
if isinstance(index, slice):
|
|
60
|
-
return tensor_index_by_slice(self, index)
|
|
61
|
-
if index is None:
|
|
62
|
-
return F.expand_dims(self, 0)
|
|
63
|
-
if index is ...:
|
|
64
|
-
return self
|
|
65
|
-
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, "
|
|
66
|
-
f"list and tuple ,but got {index} with type {type(index)}.")
|
|
171
|
+
new_index, tensor_update_types, tensor_update_args = getitem_tensor_index_info(
|
|
172
|
+
self, index)
|
|
173
|
+
return data_update(tensor_update_types, tensor_update_args, self, new_index)
|
|
67
174
|
|
|
68
175
|
|
|
69
176
|
def _tensor_setitem(self, index, value):
|
|
70
177
|
"""Handle tensor setitem"""
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
return tensor_setitem_by_tuple(self, index, value)
|
|
80
|
-
if isinstance(index, bool):
|
|
81
|
-
return tensor_setitem_by_bool(self, index, value)
|
|
82
|
-
if isinstance(index, int):
|
|
83
|
-
return tensor_setitem_by_number(self, index, value)
|
|
84
|
-
if isinstance(index, slice):
|
|
85
|
-
return tensor_setitem_by_slice(self, index, value)
|
|
86
|
-
if index in (None, ...):
|
|
87
|
-
return tensor_setitem_by_ellipsis(self, index, value)
|
|
88
|
-
|
|
89
|
-
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), bool, tensor, \
|
|
90
|
-
list and tuple, but got {index} with type{type(index)}")
|
|
178
|
+
setitem_info = setitem_tensor_index_info(self, index, value)
|
|
179
|
+
new_index = setitem_info[0]
|
|
180
|
+
v_transfer_types = setitem_info[1]
|
|
181
|
+
v_transfer_args = setitem_info[2]
|
|
182
|
+
data_update_types = setitem_info[3]
|
|
183
|
+
data_update_args = setitem_info[4]
|
|
184
|
+
value = value_update(v_transfer_types, v_transfer_args, self, value)
|
|
185
|
+
return data_update(data_update_types, data_update_args, self, new_index, value)
|
|
91
186
|
|
|
92
187
|
|
|
93
188
|
tensor_operator_registry.register("__getitem__", _tensor_getitem)
|
|
@@ -171,6 +266,13 @@ tensor_operator_registry.register('__rpow__', _tensor_rpow)
|
|
|
171
266
|
tensor_operator_registry.register('__floordiv__', _tensor_floordiv)
|
|
172
267
|
|
|
173
268
|
|
|
269
|
+
def _scalar_to_tensor(input_x):
|
|
270
|
+
if ops.isconstant(input_x):
|
|
271
|
+
return P.ScalarToTensor()(input_x, ops.dtype(input_x))
|
|
272
|
+
# use add Tensor([0]) cast scalar to tensor.
|
|
273
|
+
return ops.add(input_x, mutable(Tensor(0)))
|
|
274
|
+
|
|
275
|
+
|
|
174
276
|
def tensor_item(data, *args):
|
|
175
277
|
"""Tensor getitem by index whose dtype is int or tuple with int."""
|
|
176
278
|
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
@@ -245,13 +347,9 @@ def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
|
|
|
245
347
|
|
|
246
348
|
def _broadcast(broadcast_shape, x):
|
|
247
349
|
"""Broadcast tensor to the required shape."""
|
|
248
|
-
if
|
|
350
|
+
if F.shape(x) == broadcast_shape:
|
|
249
351
|
return x
|
|
250
|
-
|
|
251
|
-
if multiples:
|
|
252
|
-
x = F.reshape(x, const_utils.expanded_shape(F.shape(x), len(multiples) - F.rank(x)))
|
|
253
|
-
return F.tile(x, multiples)
|
|
254
|
-
return x
|
|
352
|
+
return F.broadcast_to(x, broadcast_shape)
|
|
255
353
|
|
|
256
354
|
|
|
257
355
|
def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, item):
|
|
@@ -291,6 +389,46 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
|
|
|
291
389
|
return tuple_index_new
|
|
292
390
|
|
|
293
391
|
|
|
392
|
+
def handle_empty_tensor(arg, data):
|
|
393
|
+
"""handle data update with empty tensor"""
|
|
394
|
+
if 0 in arg:
|
|
395
|
+
init_func = Zero()
|
|
396
|
+
init_func.__enable_zero_dim__ = True
|
|
397
|
+
return Tensor(shape=arg, dtype=data.dtype, init=init_func)
|
|
398
|
+
return const_utils.make_tensor([], data.dtype, arg)
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def handle_multi_dim_index_tensor(new_index, arg):
|
|
402
|
+
"""handle data update with multi dim index tensor"""
|
|
403
|
+
slice_cnt = 0
|
|
404
|
+
new_indies_tensor = []
|
|
405
|
+
if len(arg) == 1:
|
|
406
|
+
broadcast_shape = arg[0]
|
|
407
|
+
new_index = hyper_map(F.partial(Tensor), new_index)
|
|
408
|
+
broadcast_tensors = hyper_map(
|
|
409
|
+
F.partial(_broadcast, broadcast_shape), new_index)
|
|
410
|
+
new_broadcast_tensors = ()
|
|
411
|
+
for tensor in broadcast_tensors:
|
|
412
|
+
new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
|
|
413
|
+
new_index = stack(new_broadcast_tensors)
|
|
414
|
+
return new_index
|
|
415
|
+
broadcast_shape, final_shape, index_tensor_new_shape, slice_shapes, tensor_positions, fancy_position = arg
|
|
416
|
+
for i, index in enumerate(new_index):
|
|
417
|
+
if i in tensor_positions:
|
|
418
|
+
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
|
|
419
|
+
Tensor(index))
|
|
420
|
+
new_indies_tensor.append(F.cast(transform_tensor, mstype.int64))
|
|
421
|
+
else:
|
|
422
|
+
shape = const_utils.compute_slice_shape(
|
|
423
|
+
slice_shapes, len(broadcast_shape), slice_cnt, fancy_position)
|
|
424
|
+
array = Tensor(index).reshape(shape)
|
|
425
|
+
slice_index_tensor = _broadcast(final_shape, array)
|
|
426
|
+
new_indies_tensor.append(F.cast(slice_index_tensor, mstype.int64))
|
|
427
|
+
slice_cnt += 1
|
|
428
|
+
new_index = stack(new_indies_tensor)
|
|
429
|
+
return new_index
|
|
430
|
+
|
|
431
|
+
|
|
294
432
|
def _expand_data_dims(data, tuple_index):
|
|
295
433
|
"""expand the data's dim with 'None' and 'Boolean' in tuple_index"""
|
|
296
434
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
@@ -313,12 +451,34 @@ def _expand_data_dims(data, tuple_index):
|
|
|
313
451
|
return data, tuple_index_new
|
|
314
452
|
|
|
315
453
|
|
|
454
|
+
def convert_variable_to_tensor_slice(slice_index):
|
|
455
|
+
"""convert mutable scalar to tensor"""
|
|
456
|
+
start = slice_get_item(slice_index, "start")
|
|
457
|
+
stop = slice_get_item(slice_index, "stop")
|
|
458
|
+
step = slice_get_item(slice_index, "step")
|
|
459
|
+
find_mutable_scalar = False
|
|
460
|
+
if isinstance(start, int) and not F.isconstant(start):
|
|
461
|
+
start = ops.Cast()(start, mstype.int64)
|
|
462
|
+
find_mutable_scalar = True
|
|
463
|
+
if isinstance(stop, int) and not F.isconstant(stop):
|
|
464
|
+
stop = ops.Cast()(stop, mstype.int64)
|
|
465
|
+
find_mutable_scalar = True
|
|
466
|
+
if isinstance(step, int) and not F.isconstant(step):
|
|
467
|
+
step = ops.Cast()(step, mstype.int64)
|
|
468
|
+
find_mutable_scalar = True
|
|
469
|
+
if find_mutable_scalar:
|
|
470
|
+
return F.make_slice(start, stop, step)
|
|
471
|
+
return slice_index
|
|
472
|
+
|
|
473
|
+
|
|
316
474
|
def tensor_index_by_slice(data, slice_index):
|
|
317
475
|
"""Tensor getitem by a slice."""
|
|
318
476
|
min_data_dim, max_data_dim = 1, 8
|
|
319
477
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
320
478
|
data_shape = F.shape(data)
|
|
321
|
-
|
|
479
|
+
slice_index = convert_variable_to_tensor_slice(slice_index)
|
|
480
|
+
|
|
481
|
+
is_dynamic = (F.is_sequence_value_unknown(data_shape)
|
|
322
482
|
or isinstance(slice_get_item(slice_index, "start"), Tensor)
|
|
323
483
|
or isinstance(slice_get_item(slice_index, "stop"), Tensor)
|
|
324
484
|
or isinstance(slice_get_item(slice_index, "step"), Tensor))
|
|
@@ -341,6 +501,12 @@ def get_stride_info_from_slice(data, slice_index):
|
|
|
341
501
|
data_shape = F.dyn_shape(data)
|
|
342
502
|
begin_strides, end_strides, step_strides = [], [], []
|
|
343
503
|
start, stop, step = get_slice_stride(slice_index, data_shape[0])
|
|
504
|
+
if start.ndim > 0:
|
|
505
|
+
start = start.item()
|
|
506
|
+
if stop.ndim > 0:
|
|
507
|
+
stop = stop.item()
|
|
508
|
+
if step.ndim > 0:
|
|
509
|
+
step = step.item()
|
|
344
510
|
begin_strides.append(start)
|
|
345
511
|
end_strides.append(stop)
|
|
346
512
|
step_strides.append(step)
|
|
@@ -370,19 +536,10 @@ def _tensor_index_by_bool(data, bool_value):
|
|
|
370
536
|
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
|
|
371
537
|
|
|
372
538
|
|
|
373
|
-
def check_range(x, dim_size):
|
|
374
|
-
"""Check whether x is within the range of dim_size"""
|
|
375
|
-
tensor_x = const_utils.make_tensor(x)
|
|
376
|
-
if tensor_x >= dim_size or tensor_x < -dim_size:
|
|
377
|
-
return tensor_x
|
|
378
|
-
tensor_x = tensor_x % dim_size
|
|
379
|
-
return tensor_x
|
|
380
|
-
|
|
381
|
-
|
|
382
539
|
def get_stride_info_from_integer(tensor_int):
|
|
383
540
|
"""Convert integer to slice"""
|
|
384
541
|
begin_strides = [tensor_int]
|
|
385
|
-
end_strides = [tensor_int +
|
|
542
|
+
end_strides = [tensor_int + 1]
|
|
386
543
|
step_strides = [const_utils.make_tensor(1)]
|
|
387
544
|
begin_tensor = stack(begin_strides)
|
|
388
545
|
end_tensor = stack(end_strides)
|
|
@@ -398,10 +555,9 @@ def _tensor_index_by_integer(data, int_index):
|
|
|
398
555
|
if data.ndim < 1 or data.ndim > 8:
|
|
399
556
|
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
400
557
|
|
|
401
|
-
if
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
begin_strides, end_strides, step_strides = get_stride_info_from_integer(transformed_tensor)
|
|
558
|
+
if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
|
|
559
|
+
tensor_index = _scalar_to_tensor(int_index)
|
|
560
|
+
begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
|
|
405
561
|
else:
|
|
406
562
|
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
|
407
563
|
begin_strides, end_strides, step_strides = \
|
|
@@ -415,16 +571,35 @@ def _tensor_index_by_integer(data, int_index):
|
|
|
415
571
|
return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
416
572
|
|
|
417
573
|
|
|
574
|
+
def _check_dim_shape_valid(data, tensor_index):
|
|
575
|
+
"""check dim and shape of tensor_index for tensor(bool) indexing"""
|
|
576
|
+
if data.ndim < tensor_index.ndim:
|
|
577
|
+
raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
|
|
578
|
+
f"dim of index:{tensor_index.ndim}, dim of data:{data.ndim}")
|
|
579
|
+
if data.shape[:tensor_index.ndim] != tensor_index.shape[:]:
|
|
580
|
+
raise IndexError(f"The shape of index {tensor_index.shape} does not match the shape "
|
|
581
|
+
f"of the indexed data {data.shape}")
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def tensor_index_by_bool_tensor(data, tensor_index):
|
|
585
|
+
"""Tensor getitem by a bool tensor"""
|
|
586
|
+
_check_dim_shape_valid(data, tensor_index)
|
|
587
|
+
tensor_index = tensor_index.nonzero()
|
|
588
|
+
return F.gather_nd(data, tensor_index)
|
|
589
|
+
|
|
590
|
+
|
|
418
591
|
def tensor_index_by_tensor(data, tensor_index):
|
|
419
592
|
"""Tensor getitem by a single tensor"""
|
|
420
593
|
min_data_dim, max_data_dim = 0, 7
|
|
421
594
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
595
|
+
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
|
|
596
|
+
return F.gather(data, tensor_index, 0)
|
|
597
|
+
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
|
|
598
|
+
return tensor_index_by_bool_tensor(data, tensor_index)
|
|
599
|
+
exp_msg = const_utils.gen_exception_msg(
|
|
600
|
+
"The tensor index must be int or bool type, but got {}.", F.dtype(tensor_index))
|
|
601
|
+
const_utils.raise_index_error(exp_msg)
|
|
602
|
+
return data
|
|
428
603
|
|
|
429
604
|
|
|
430
605
|
def tensor_index_by_list(data, list_index):
|
|
@@ -435,10 +610,13 @@ def tensor_index_by_list(data, list_index):
|
|
|
435
610
|
data_shape = F.shape(data)
|
|
436
611
|
indexes_types = hyper_map(toptypeof, list_index)
|
|
437
612
|
if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)):
|
|
438
|
-
if data_shape[0]
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
613
|
+
if not F.isconstant(data_shape[0]):
|
|
614
|
+
if all(isinstance(i, bool) for i in list_index):
|
|
615
|
+
const_utils.raise_unimplemented_error(
|
|
616
|
+
"Not supported to the dynamic shape tensor slice by using list of Boolean type")
|
|
617
|
+
tensor_index = const_utils.sequence_to_index(list_index, None)
|
|
618
|
+
else:
|
|
619
|
+
tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
|
|
442
620
|
if tensor_index is False:
|
|
443
621
|
const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
|
|
444
622
|
return F.gather(data, tensor_index, 0)
|
|
@@ -449,18 +627,28 @@ def tensor_index_by_list(data, list_index):
|
|
|
449
627
|
return tensor_index_by_tuple(data, tuple_index_new)
|
|
450
628
|
|
|
451
629
|
|
|
630
|
+
def convert_tupleslice_to_tensor(tuple_index):
|
|
631
|
+
"""convert mutable scalar in slice to tensor"""
|
|
632
|
+
new_tuple_index = []
|
|
633
|
+
for item in tuple_index:
|
|
634
|
+
if isinstance(item, slice):
|
|
635
|
+
item = convert_variable_to_tensor_slice(item)
|
|
636
|
+
new_tuple_index.append(item)
|
|
637
|
+
return tuple(new_tuple_index)
|
|
638
|
+
|
|
639
|
+
|
|
452
640
|
def tensor_index_by_tuple(data, tuple_index):
|
|
453
641
|
"""Tensor getitem by tuple of various types with None"""
|
|
454
642
|
if not tuple_index:
|
|
455
643
|
return data
|
|
456
644
|
|
|
645
|
+
tuple_index = convert_tupleslice_to_tensor(tuple_index)
|
|
457
646
|
op_name = const_utils.TENSOR_GETITEM
|
|
458
647
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
459
648
|
data, tuple_index = _expand_data_dims(data, tuple_index)
|
|
460
649
|
|
|
461
650
|
min_data_dim, max_data_dim = 1, 8
|
|
462
651
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
463
|
-
|
|
464
652
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
465
653
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
466
654
|
if contain_type == const_utils.ALL_BASIC:
|
|
@@ -468,31 +656,6 @@ def tensor_index_by_tuple(data, tuple_index):
|
|
|
468
656
|
return _tensor_getitem_by_tuple(data, tuple_index, op_name)
|
|
469
657
|
|
|
470
658
|
|
|
471
|
-
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
|
|
472
|
-
"""Tensor getitem by a tuple of tensor."""
|
|
473
|
-
data_shape = F.shape(data)
|
|
474
|
-
tuple_index_len = len(tuple_index)
|
|
475
|
-
|
|
476
|
-
indexes_types = hyper_map(F.dtype, tuple_index)
|
|
477
|
-
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
|
|
478
|
-
tensor_index_shape = hyper_map(F.shape, tuple_index)
|
|
479
|
-
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
|
|
480
|
-
if 0 in broadcast_shape:
|
|
481
|
-
res_shape = broadcast_shape
|
|
482
|
-
if tuple_index_len < len(data_shape):
|
|
483
|
-
res_shape += data_shape[tuple_index_len:]
|
|
484
|
-
res = const_utils.make_tensor([], data.dtype, res_shape)
|
|
485
|
-
return res
|
|
486
|
-
|
|
487
|
-
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
|
|
488
|
-
new_broadcast_tensors = ()
|
|
489
|
-
for tensor in broadcast_tensors:
|
|
490
|
-
new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
|
|
491
|
-
indices = stack(new_broadcast_tensors)
|
|
492
|
-
result = F.gather_nd(data, indices)
|
|
493
|
-
return result
|
|
494
|
-
|
|
495
|
-
|
|
496
659
|
def get_slice_stride(slice_index, dim_size):
|
|
497
660
|
"""Get slice stride info"""
|
|
498
661
|
start = slice_get_item(slice_index, "start")
|
|
@@ -551,7 +714,7 @@ def _get_stride_info_from_tuple(data, tuple_index):
|
|
|
551
714
|
step_strides.append(step)
|
|
552
715
|
index_count = index_count + 1
|
|
553
716
|
elif isinstance(index, int):
|
|
554
|
-
int_tensor =
|
|
717
|
+
int_tensor = _scalar_to_tensor(index)
|
|
555
718
|
begin_strides.append(int_tensor)
|
|
556
719
|
end_strides.append(int_tensor + const_utils.make_tensor(1))
|
|
557
720
|
step_strides.append(const_utils.make_tensor(1))
|
|
@@ -585,7 +748,7 @@ def _get_stride_info_from_tuple(data, tuple_index):
|
|
|
585
748
|
def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|
586
749
|
"""Tensor getitem by a tuple of slice"""
|
|
587
750
|
data_shape = F.shape(data)
|
|
588
|
-
is_dynamic =
|
|
751
|
+
is_dynamic = F.is_sequence_value_unknown(data_shape)
|
|
589
752
|
for item in tuple_index:
|
|
590
753
|
if isinstance(item, slice):
|
|
591
754
|
is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
|
|
@@ -607,6 +770,39 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|
|
607
770
|
return strided_slice(data, begin_v, end_v, step_v, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
608
771
|
|
|
609
772
|
|
|
773
|
+
@_primexpr
|
|
774
|
+
def _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
775
|
+
tensor_positions_new):
|
|
776
|
+
""" parse index of bool tensor type """
|
|
777
|
+
indices = index.nonzero()
|
|
778
|
+
if indices.shape[0] == 0:
|
|
779
|
+
return None, tensor_indexes, tensor_positions_new
|
|
780
|
+
indices = F.cast(indices, mstype.int64)
|
|
781
|
+
indices = indices.T
|
|
782
|
+
for sub_index in indices:
|
|
783
|
+
tensor_positions_new.append(len(tuple_index_new))
|
|
784
|
+
tuple_index_new += (sub_index,)
|
|
785
|
+
tensor_indexes.append(sub_index)
|
|
786
|
+
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
def _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new, tensor_indexes, tensor_positions_new):
|
|
790
|
+
""" parse index of tensor type """
|
|
791
|
+
if F.dtype(index) in mstype.int_type:
|
|
792
|
+
tensor_index = F.cast(index, mstype.int64)
|
|
793
|
+
tensor_positions_new.append(len(tuple_index_new))
|
|
794
|
+
tuple_index_new += (tensor_index,)
|
|
795
|
+
tensor_indexes.append(tensor_index)
|
|
796
|
+
elif F.dtype(index) == mstype.bool_:
|
|
797
|
+
return _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
|
|
798
|
+
tensor_positions_new)
|
|
799
|
+
else:
|
|
800
|
+
exp_msg = const_utils.gen_exception_msg(
|
|
801
|
+
"The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
|
|
802
|
+
const_utils.raise_index_error(exp_msg)
|
|
803
|
+
return tuple_index_new, tensor_indexes, tensor_positions_new
|
|
804
|
+
|
|
805
|
+
|
|
610
806
|
def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|
611
807
|
"""Tensor getitem by a tuple of mixed tensor."""
|
|
612
808
|
slice_is_tensor = False
|
|
@@ -617,51 +813,49 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|
|
617
813
|
or isinstance(slice_get_item(item, "step"), Tensor)
|
|
618
814
|
if slice_is_tensor:
|
|
619
815
|
const_utils.raise_index_error("Not supported when slice has tensor")
|
|
620
|
-
|
|
621
|
-
tensor_indexes, slice_indexes = [], []
|
|
816
|
+
|
|
622
817
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
623
818
|
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
|
|
624
819
|
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
|
625
|
-
tuple_index_new, slice_shapes = (), ()
|
|
626
820
|
data_shape = F.shape(data)
|
|
821
|
+
tensor_indexes, slice_indexes = [], []
|
|
822
|
+
tuple_index_new, slice_shapes = (), ()
|
|
823
|
+
slice_positions_new, tensor_positions_new = [], []
|
|
627
824
|
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
|
|
628
825
|
if i in int_positions:
|
|
629
826
|
int_index = const_utils.check_range(index, dim_size)
|
|
630
827
|
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
|
|
631
|
-
if
|
|
632
|
-
|
|
633
|
-
tensor_index = check_range(index, dyn_shape[i])
|
|
828
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
829
|
+
tensor_index = _scalar_to_tensor(int_index)
|
|
634
830
|
tensor_index = F.cast(tensor_index, mstype.int64)
|
|
831
|
+
tensor_positions_new.append(len(tuple_index_new))
|
|
635
832
|
tuple_index_new += (tensor_index,)
|
|
636
833
|
tensor_indexes.append(tensor_index)
|
|
637
|
-
tensor_positions += (i,)
|
|
638
834
|
elif i in sequence_positions:
|
|
639
835
|
tensor_index = const_utils.sequence_to_index(index, dim_size)
|
|
640
836
|
if tensor_index is False:
|
|
641
837
|
const_utils.raise_index_error("The sequence element(tuple/list) in tuple index can't be empty.")
|
|
838
|
+
tensor_positions_new.append(len(tuple_index_new))
|
|
642
839
|
tuple_index_new += (tensor_index,)
|
|
643
840
|
tensor_indexes.append(tensor_index)
|
|
644
|
-
tensor_positions += (i,)
|
|
645
841
|
elif i in tensor_positions:
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
tensor_index = F.cast(index, mstype.int64)
|
|
652
|
-
tuple_index_new += (tensor_index,)
|
|
653
|
-
tensor_indexes.append(tensor_index)
|
|
842
|
+
tuple_index_new, tensor_indexes, tensor_positions_new = \
|
|
843
|
+
_tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new,
|
|
844
|
+
tensor_indexes, tensor_positions_new)
|
|
845
|
+
if tuple_index_new is None:
|
|
846
|
+
return Tensor([])
|
|
654
847
|
elif i in slice_positions:
|
|
655
848
|
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
|
|
656
849
|
slice_shapes += (len(slice_ele_list_index),)
|
|
850
|
+
slice_positions_new.append(len(tuple_index_new))
|
|
657
851
|
tuple_index_new += (slice_ele_list_index,)
|
|
658
852
|
slice_indexes.append(slice_ele_list_index)
|
|
659
|
-
|
|
660
853
|
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
|
661
854
|
broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
|
|
662
|
-
const_utils.generate_index_info_from_tuple_of_mixed_tensors(
|
|
855
|
+
const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions_new, tensor_indexes_shapes,
|
|
663
856
|
slice_shapes, op_name)
|
|
664
857
|
|
|
858
|
+
tuple_index_len = len(tuple_index)
|
|
665
859
|
if 0 in final_shape + data_shape:
|
|
666
860
|
if tuple_index_len < len(data_shape):
|
|
667
861
|
final_shape = final_shape + data_shape[tuple_index_len:]
|
|
@@ -670,11 +864,11 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|
|
670
864
|
final_index_tensors = []
|
|
671
865
|
slice_cnt = 0
|
|
672
866
|
for i, index in enumerate(tuple_index_new):
|
|
673
|
-
if i in
|
|
867
|
+
if i in tensor_positions_new:
|
|
674
868
|
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
|
|
675
869
|
index)
|
|
676
870
|
final_index_tensors.append(transform_tensor)
|
|
677
|
-
elif i in
|
|
871
|
+
elif i in slice_positions_new:
|
|
678
872
|
slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
|
|
679
873
|
slice_shapes, fancy_position)
|
|
680
874
|
final_index_tensors.append(slice_index_tensor)
|
|
@@ -709,7 +903,6 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
|
|
|
709
903
|
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
|
|
710
904
|
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
|
711
905
|
tuple_index_new, slice_shapes = (), ()
|
|
712
|
-
|
|
713
906
|
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
|
|
714
907
|
if i in int_positions:
|
|
715
908
|
int_index = const_utils.check_range(index, dim_size)
|
|
@@ -726,7 +919,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
|
|
|
726
919
|
invalid = const_utils.check_type_invalid(F.dtype(index), mstype.int_type)
|
|
727
920
|
if invalid:
|
|
728
921
|
exp_msg = const_utils.gen_exception_msg(
|
|
729
|
-
"The tensor element in tuple index must be int type, but got {}.", F.dtype(index))
|
|
922
|
+
"The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
|
|
730
923
|
const_utils.raise_index_error(exp_msg)
|
|
731
924
|
tensor_index = F.cast(index, mstype.int64)
|
|
732
925
|
tuple_index_new += (tensor_index,)
|
|
@@ -791,11 +984,11 @@ def _generate_updates_from_sequence(data, index, value, op_type):
|
|
|
791
984
|
def _generate_updates_from_tensor(data, index, value, op_type):
|
|
792
985
|
"""Generate an updates tensor from a tensor."""
|
|
793
986
|
value = value.astype(data.dtype)
|
|
794
|
-
if
|
|
987
|
+
if F.is_sequence_value_unknown(F.shape(data)):
|
|
795
988
|
data_shape = F.dyn_shape(data)
|
|
796
989
|
index_shape = F.dyn_shape(index)
|
|
797
990
|
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
|
798
|
-
updates =
|
|
991
|
+
updates = ops.broadcast_to(value, updates_shape)
|
|
799
992
|
return updates
|
|
800
993
|
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type, False)
|
|
801
994
|
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
|
|
@@ -815,6 +1008,7 @@ def tensor_setitem_by_tensor(self, index, value):
|
|
|
815
1008
|
|
|
816
1009
|
|
|
817
1010
|
def tensor_setitem_by_tuple(self, index, value):
|
|
1011
|
+
index = convert_tupleslice_to_tensor(index)
|
|
818
1012
|
if isinstance(value, (int, float, bool)):
|
|
819
1013
|
index = format_tuple_indices(index)
|
|
820
1014
|
return tensor_setitem_by_tuple_with_number(self, index, value)
|
|
@@ -832,6 +1026,7 @@ def tensor_setitem_by_number(self, index, value):
|
|
|
832
1026
|
|
|
833
1027
|
|
|
834
1028
|
def tensor_setitem_by_slice(self, index, value):
|
|
1029
|
+
index = convert_variable_to_tensor_slice(index)
|
|
835
1030
|
if isinstance(value, (int, float, bool)):
|
|
836
1031
|
return tensor_setitem_by_slice_with_number(self, index, value)
|
|
837
1032
|
if isinstance(value, Tensor):
|
|
@@ -852,28 +1047,29 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
|
|
852
1047
|
if F.rank(index) == 0:
|
|
853
1048
|
index = F.expand_dims(index, -1)
|
|
854
1049
|
updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
|
855
|
-
|
|
1050
|
+
data_shape = F.shape(data)
|
|
1051
|
+
first_val = data_shape[0]
|
|
1052
|
+
if not F.isconstant(first_val):
|
|
1053
|
+
first_val = -1
|
|
1054
|
+
index = F.select(index < 0, index + first_val, index)
|
|
856
1055
|
index = F.expand_dims(index, -1)
|
|
857
1056
|
if F.rank(index) < 2:
|
|
858
1057
|
index = F.expand_dims(index, 0)
|
|
859
1058
|
updates = F.expand_dims(updates, 0)
|
|
1059
|
+
if is_parameter(data):
|
|
1060
|
+
F.scatter_nd_update(data, index, updates)
|
|
1061
|
+
return data
|
|
860
1062
|
return F.tensor_scatter_update(data, index, updates)
|
|
861
1063
|
|
|
862
1064
|
|
|
863
1065
|
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
|
864
1066
|
"""Set a tensor item by a bool tensor with a tensor."""
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
|
872
|
-
dtype = F.dtype(data)
|
|
873
|
-
u_cast = F.cast(value, dtype)
|
|
874
|
-
one_data = F.ones_like(data)
|
|
875
|
-
u = F.tensor_mul(one_data, u_cast)
|
|
876
|
-
result = F.select(index, u, data)
|
|
1067
|
+
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
|
1068
|
+
index = F.broadcast_to(index, data.shape)
|
|
1069
|
+
value = F.cast(value, F.dtype(data))
|
|
1070
|
+
value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
|
|
1071
|
+
value = F.broadcast_to(value, data.shape)
|
|
1072
|
+
result = F.select(index, value, data)
|
|
877
1073
|
return result
|
|
878
1074
|
|
|
879
1075
|
|
|
@@ -884,7 +1080,7 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|
|
884
1080
|
if tensor_dtype == const_utils.INT_:
|
|
885
1081
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
|
886
1082
|
|
|
887
|
-
if
|
|
1083
|
+
if F.is_sequence_value_unknown(F.shape(data)):
|
|
888
1084
|
const_utils.raise_unimplemented_error(
|
|
889
1085
|
"Not supported to the dynamic shape tensor slice by using tensor of Boolean type")
|
|
890
1086
|
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
|
@@ -898,11 +1094,13 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
|
|
|
898
1094
|
def tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
899
1095
|
"""Assigns the tensor by tensor with tuple value."""
|
|
900
1096
|
index_dtype = F.dtype(index)
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
1097
|
+
if index_dtype in (mstype.int32, mstype.int64):
|
|
1098
|
+
return _tensor_setitem_by_tensor_with_sequence(data, index, value)
|
|
1099
|
+
if index_dtype == mstype.bool_:
|
|
1100
|
+
return _tensor_setitem_by_bool_tensor_with_sequence(data, index, value)
|
|
1101
|
+
exp_msg = const_utils.gen_exception_msg("The tensor index must be int or bool type, but got {}.", index_dtype)
|
|
1102
|
+
const_utils.raise_index_error(exp_msg)
|
|
1103
|
+
return None
|
|
906
1104
|
|
|
907
1105
|
|
|
908
1106
|
def _tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
@@ -912,6 +1110,12 @@ def _tensor_setitem_by_tensor_with_sequence(data, index, value):
|
|
|
912
1110
|
return F.tensor_scatter_update(data, index, updates)
|
|
913
1111
|
|
|
914
1112
|
|
|
1113
|
+
def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
|
|
1114
|
+
"""Set a tensor item by a bool tensor with a tuple."""
|
|
1115
|
+
value = sequence_to_tensor(value, F.dtype(data))
|
|
1116
|
+
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value)
|
|
1117
|
+
|
|
1118
|
+
|
|
915
1119
|
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
|
916
1120
|
"""Givens a scalar assign to tensor by slice"""
|
|
917
1121
|
value = F.fill(F.dtype(data), (), value)
|
|
@@ -937,7 +1141,7 @@ def tensor_copy_slice_from_slice(data, input_slice, value):
|
|
|
937
1141
|
if dim0_size >= data_shape[0]:
|
|
938
1142
|
dim0_size = data_shape[0:1]
|
|
939
1143
|
value_shape = P.Concat(-1)((dim0_size, data_shape[1:]))
|
|
940
|
-
value =
|
|
1144
|
+
value = ops.broadcast_to(value, value_shape)
|
|
941
1145
|
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
942
1146
|
|
|
943
1147
|
|
|
@@ -948,8 +1152,8 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
|
|
948
1152
|
if check_result:
|
|
949
1153
|
data_shape = F.shape(data)
|
|
950
1154
|
step = const_utils.get_step_from_slice(input_slice)
|
|
951
|
-
if step == 1:
|
|
952
|
-
if
|
|
1155
|
+
if step == 1 and not const_utils.is_ascend():
|
|
1156
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
953
1157
|
return tensor_copy_slice_from_slice(data, input_slice, value)
|
|
954
1158
|
start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
|
|
955
1159
|
dim0_size = stop - start
|
|
@@ -958,7 +1162,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
|
|
958
1162
|
value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
|
|
959
1163
|
value = _broadcast(value_shape, value)
|
|
960
1164
|
return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
|
|
961
|
-
if
|
|
1165
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
962
1166
|
const_utils.raise_unimplemented_error(
|
|
963
1167
|
"Not supported to take the subscript of dynamic shape tensor slice setitem")
|
|
964
1168
|
indices = const_utils.slice2indices(input_slice, data_shape)
|
|
@@ -982,7 +1186,7 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
|
982
1186
|
dim1_start, dim1_stop, _ = get_slice_stride(tuple_index[1], data_shape[1])
|
|
983
1187
|
if dim1_stop - dim1_start <= 0:
|
|
984
1188
|
return data
|
|
985
|
-
dim0_start =
|
|
1189
|
+
dim0_start = _scalar_to_tensor(tuple_index[0])
|
|
986
1190
|
dim0_stop = dim0_start + const_utils.make_tensor(1)
|
|
987
1191
|
start = (dim0_start, dim1_start)
|
|
988
1192
|
stop = (dim0_stop, dim1_stop)
|
|
@@ -994,7 +1198,7 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
|
994
1198
|
if dim1_size > data_shape[1]:
|
|
995
1199
|
dim1_size = data_shape[1:2]
|
|
996
1200
|
value_shape = P.Concat(-1)((dim1_size, data_shape[2:]))
|
|
997
|
-
value =
|
|
1201
|
+
value = ops.broadcast_to(value, value_shape)
|
|
998
1202
|
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
999
1203
|
|
|
1000
1204
|
|
|
@@ -1003,8 +1207,8 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1003
1207
|
op_name = const_utils.TENSOR_SETITEM
|
|
1004
1208
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
1005
1209
|
|
|
1006
|
-
if const_utils.use_copy_slice(tuple_index):
|
|
1007
|
-
if
|
|
1210
|
+
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1211
|
+
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1008
1212
|
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1009
1213
|
dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
|
|
1010
1214
|
if dim1_stop - dim1_start <= 0:
|
|
@@ -1024,7 +1228,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1024
1228
|
if len(tuple_index) == 1:
|
|
1025
1229
|
data[tuple_index[0]] = value
|
|
1026
1230
|
return data
|
|
1027
|
-
|
|
1028
1231
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
1029
1232
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
1030
1233
|
|
|
@@ -1058,14 +1261,20 @@ def tensor_setitem_by_number_with_sequence(data, index, value):
|
|
|
1058
1261
|
def tensor_setitem_by_number_with_tensor(data, index, value):
|
|
1059
1262
|
"""Assigns the tensor by number with tensor value."""
|
|
1060
1263
|
data_shape = F.shape(data)
|
|
1061
|
-
if
|
|
1062
|
-
index =
|
|
1264
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
1265
|
+
index = _scalar_to_tensor(index)
|
|
1063
1266
|
index = F.expand_dims(index, -1)
|
|
1064
1267
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value)
|
|
1065
1268
|
|
|
1269
|
+
dim_size = data_shape[0]
|
|
1270
|
+
if index < -dim_size or index >= dim_size:
|
|
1271
|
+
raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
|
|
1066
1272
|
index = const_utils.int_to_index(index, data_shape)
|
|
1067
1273
|
value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
|
|
1068
1274
|
value = _broadcast(value_shape, value.astype(F.dtype(data)))
|
|
1275
|
+
if is_parameter(data):
|
|
1276
|
+
F.scatter_nd_update(data, index, value)
|
|
1277
|
+
return data
|
|
1069
1278
|
return F.tensor_scatter_update(data, index, value)
|
|
1070
1279
|
|
|
1071
1280
|
|
|
@@ -1073,7 +1282,7 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
|
|
|
1073
1282
|
"""Assigns the tensor by ellipsis with number value."""
|
|
1074
1283
|
data_shape = F.shape(data)
|
|
1075
1284
|
data_dtype = F.dtype(data)
|
|
1076
|
-
if
|
|
1285
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
1077
1286
|
value = F.fill(F.dtype(data), (), value)
|
|
1078
1287
|
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1079
1288
|
return F.fill(data_dtype, data_shape, value)
|
|
@@ -1085,9 +1294,9 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
|
|
|
1085
1294
|
data_dtype = F.dtype(data)
|
|
1086
1295
|
value = value.astype(data_dtype)
|
|
1087
1296
|
|
|
1088
|
-
if
|
|
1297
|
+
if F.is_sequence_value_unknown(data_shape):
|
|
1089
1298
|
data_shape = F.dyn_shape(data)
|
|
1090
|
-
data =
|
|
1299
|
+
data = ops.broadcast_to(value, data_shape)
|
|
1091
1300
|
return data
|
|
1092
1301
|
value_shape = F.shape(value)
|
|
1093
1302
|
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
|
@@ -1115,9 +1324,9 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1115
1324
|
elif isinstance(value, float):
|
|
1116
1325
|
value = const_utils.make_tensor(value, mstype.float32)
|
|
1117
1326
|
|
|
1118
|
-
if
|
|
1327
|
+
if F.is_sequence_value_unknown(data_shape) and index:
|
|
1119
1328
|
data_shape = F.dyn_shape(data)
|
|
1120
|
-
data =
|
|
1329
|
+
data = ops.broadcast_to(value, data_shape)
|
|
1121
1330
|
return data
|
|
1122
1331
|
value_shape = F.shape(value)
|
|
1123
1332
|
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
|
@@ -1143,6 +1352,8 @@ def format_list_indices(list_indices, length):
|
|
|
1143
1352
|
# If eyery element in list is bool, it's treated as 1-D bool tensor.
|
|
1144
1353
|
# If every element in list is int(not all bool), it's treated as int tensor.
|
|
1145
1354
|
if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
|
|
1355
|
+
if not F.isconstant(length):
|
|
1356
|
+
return const_utils.sequence_to_index(list_indices, None)
|
|
1146
1357
|
return const_utils.sequence_to_index(list_indices, length)
|
|
1147
1358
|
# If list contains other types(.../list/tuple/None), it's treated as a tuple
|
|
1148
1359
|
return const_utils.deep_tuple(list_indices)
|
|
@@ -1162,10 +1373,34 @@ def format_tuple_indices(tuple_indices):
|
|
|
1162
1373
|
return res
|
|
1163
1374
|
|
|
1164
1375
|
|
|
1376
|
+
@_primexpr
|
|
1377
|
+
def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
|
|
1378
|
+
""" Parse bool tensor index """
|
|
1379
|
+
index_out = index_out.nonzero()
|
|
1380
|
+
if index_out.shape[0] == 0:
|
|
1381
|
+
return None, shapes, cur_dim
|
|
1382
|
+
for i in range(index_out.shape[1]):
|
|
1383
|
+
out = index_out[:, i]
|
|
1384
|
+
indices_out += (out,)
|
|
1385
|
+
shapes.append(F.shape(out))
|
|
1386
|
+
cur_dim += 1
|
|
1387
|
+
return indices_out, shapes, cur_dim
|
|
1388
|
+
|
|
1389
|
+
|
|
1390
|
+
def remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim):
|
|
1391
|
+
""" Parse tensor index """
|
|
1392
|
+
if index_out.dtype == mstype.bool_:
|
|
1393
|
+
return remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim)
|
|
1394
|
+
indices_out += (index_out,)
|
|
1395
|
+
shapes.append(F.shape(index_out))
|
|
1396
|
+
cur_dim += 1
|
|
1397
|
+
return indices_out, shapes, cur_dim
|
|
1398
|
+
|
|
1399
|
+
|
|
1165
1400
|
def remove_expanded_dims(tuple_index, data_shape, value):
|
|
1166
1401
|
"""Removes expanded dimensions in tuple_index and value."""
|
|
1167
1402
|
not_expanded_dim = ()
|
|
1168
|
-
shapes =
|
|
1403
|
+
shapes = []
|
|
1169
1404
|
has_true = False
|
|
1170
1405
|
has_false = False
|
|
1171
1406
|
has_sequence = False
|
|
@@ -1192,11 +1427,12 @@ def remove_expanded_dims(tuple_index, data_shape, value):
|
|
|
1192
1427
|
idx_advanced = 0
|
|
1193
1428
|
idx_tensor = i
|
|
1194
1429
|
if isinstance(index_out, Tensor):
|
|
1195
|
-
|
|
1430
|
+
indices_out, shapes, cur_dim = \
|
|
1431
|
+
remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim)
|
|
1432
|
+
if indices_out is None:
|
|
1433
|
+
return False, value, 0
|
|
1434
|
+
if index_out.dtype != mstype.bool_ and F.rank(index_out) > 0:
|
|
1196
1435
|
has_sequence = True
|
|
1197
|
-
indices_out += (index_out,)
|
|
1198
|
-
shapes += (F.shape(index_out),)
|
|
1199
|
-
cur_dim += 1
|
|
1200
1436
|
has_true = has_true or index_out is True
|
|
1201
1437
|
has_false = has_false or index_out is False
|
|
1202
1438
|
else:
|
|
@@ -1229,11 +1465,21 @@ def format_index(idx, data_shape, cur_dim):
|
|
|
1229
1465
|
elif isinstance(idx, int) and not isinstance(idx, bool):
|
|
1230
1466
|
idx = const_utils.make_tensor(idx, mstype.int64, None, data_shape[cur_dim])
|
|
1231
1467
|
elif isinstance(idx, Tensor):
|
|
1232
|
-
|
|
1233
|
-
|
|
1468
|
+
tensor_dtype = const_utils.get_index_tensor_dtype(idx.dtype)
|
|
1469
|
+
if tensor_dtype == const_utils.INT_:
|
|
1470
|
+
idx = F.select(idx < 0, idx + data_shape[cur_dim], idx)
|
|
1471
|
+
elif tensor_dtype == const_utils.BOOL_:
|
|
1472
|
+
# index with tensor(bool) type is processed in remove_expanded_dims()
|
|
1473
|
+
pass
|
|
1234
1474
|
return idx
|
|
1235
1475
|
|
|
1236
1476
|
|
|
1477
|
+
@_primexpr
|
|
1478
|
+
def _check_shape_mul(shape):
|
|
1479
|
+
if F.shape_mul(shape) == 0:
|
|
1480
|
+
raise ValueError('zero-size tensors are not supported.')
|
|
1481
|
+
|
|
1482
|
+
|
|
1237
1483
|
def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, where=True, dtype=None):
|
|
1238
1484
|
"""
|
|
1239
1485
|
Applies comparison based on cmp_fn and reduction based on reduce_fn.
|
|
@@ -1250,8 +1496,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
|
|
|
1250
1496
|
not isinstance(initial, (int, float, bool, Tensor))):
|
|
1251
1497
|
const_utils.raise_type_error('initial must be scalar')
|
|
1252
1498
|
|
|
1253
|
-
|
|
1254
|
-
const_utils.raise_value_error('zero-size tensors are not supported.')
|
|
1499
|
+
_check_shape_mul(shape)
|
|
1255
1500
|
|
|
1256
1501
|
if initial is not None:
|
|
1257
1502
|
if isinstance(initial, Tensor):
|