mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -14,21 +14,19 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""constexpr util"""
|
|
16
16
|
|
|
17
|
-
from itertools import compress, zip_longest
|
|
18
|
-
from functools import partial
|
|
19
|
-
from collections import deque
|
|
20
17
|
import operator
|
|
18
|
+
from functools import partial
|
|
19
|
+
from itertools import compress
|
|
21
20
|
|
|
22
21
|
import numpy as np
|
|
23
|
-
|
|
24
|
-
from mindspore.ops.primitive import constexpr
|
|
25
|
-
from mindspore import log as logger
|
|
22
|
+
from mindspore import _checkparam as validator
|
|
26
23
|
from mindspore.common import dtype as mstype
|
|
27
|
-
from mindspore.common.tensor import Tensor
|
|
28
24
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
29
|
-
from mindspore.
|
|
30
|
-
from mindspore._checkparam import Validator as validator
|
|
25
|
+
from mindspore.common.tensor import Tensor
|
|
31
26
|
from mindspore.ops import operations as P
|
|
27
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
28
|
+
from mindspore import log as logger
|
|
29
|
+
from mindspore import context
|
|
32
30
|
|
|
33
31
|
ALL_TENSOR = 0
|
|
34
32
|
NO_TENSOR = 1
|
|
@@ -117,9 +115,9 @@ def make_empty_slice():
|
|
|
117
115
|
|
|
118
116
|
|
|
119
117
|
@constexpr
|
|
120
|
-
def _deep_list(array_like, dim_size
|
|
118
|
+
def _deep_list(array_like, dim_size=None):
|
|
121
119
|
"""convert nested tuple/list mixtures to pure nested list"""
|
|
122
|
-
if dim_size
|
|
120
|
+
if dim_size is not None:
|
|
123
121
|
array_like = check_range(array_like, dim_size)
|
|
124
122
|
if isinstance(array_like, (list, tuple)):
|
|
125
123
|
return list(map(lambda x: _deep_list(x, dim_size), array_like))
|
|
@@ -160,7 +158,7 @@ def _deep_tensor_to_nparray(array_like):
|
|
|
160
158
|
|
|
161
159
|
@constexpr
|
|
162
160
|
def check_range(x, dim_size):
|
|
163
|
-
if dim_size
|
|
161
|
+
if dim_size is None:
|
|
164
162
|
return x
|
|
165
163
|
if isinstance(x, int) and not isinstance(x, bool):
|
|
166
164
|
if x >= dim_size or x < -dim_size:
|
|
@@ -170,7 +168,7 @@ def check_range(x, dim_size):
|
|
|
170
168
|
|
|
171
169
|
|
|
172
170
|
@constexpr
|
|
173
|
-
def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size
|
|
171
|
+
def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=None):
|
|
174
172
|
"""
|
|
175
173
|
Converts the input to tensor.
|
|
176
174
|
|
|
@@ -192,9 +190,9 @@ def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=-1):
|
|
|
192
190
|
return Tensor(np.zeros(data_shape), dtype)
|
|
193
191
|
|
|
194
192
|
if not isinstance(a, (list, tuple, int, float, bool)):
|
|
195
|
-
raise TypeError("
|
|
193
|
+
raise TypeError(f"Input data must be `int`, `float`, `bool`, `list` or `tuple`, but got {a}")
|
|
196
194
|
|
|
197
|
-
if dim_size
|
|
195
|
+
if dim_size is not None:
|
|
198
196
|
a = check_range(a, dim_size)
|
|
199
197
|
|
|
200
198
|
if isinstance(a, (list, tuple)):
|
|
@@ -215,7 +213,6 @@ def make_tensor(a, dtype=mstype.int64, data_shape=None, dim_size=-1):
|
|
|
215
213
|
tensor_operator_registry.register('make_tensor', make_tensor)
|
|
216
214
|
|
|
217
215
|
|
|
218
|
-
@constexpr
|
|
219
216
|
def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8):
|
|
220
217
|
"""Judges whether the data dim is valid."""
|
|
221
218
|
if data_dim < min_data_dim or data_dim > max_data_dim:
|
|
@@ -223,21 +220,11 @@ def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8):
|
|
|
223
220
|
f"{max_data_dim}], but got '{data_dim}'.")
|
|
224
221
|
|
|
225
222
|
|
|
226
|
-
@constexpr
|
|
227
223
|
def get_source_shape(data_shape, value_shape):
|
|
228
224
|
"""Returns the shape of value that will be used to broadcast against data."""
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
if j not in (1, i):
|
|
233
|
-
cannot_broadcast = True
|
|
234
|
-
for i in range(len(value_shape) - len(data_shape)):
|
|
235
|
-
source_shape = data_shape
|
|
236
|
-
if value_shape[i] != 1:
|
|
237
|
-
cannot_broadcast = True
|
|
238
|
-
if cannot_broadcast:
|
|
239
|
-
raise ValueError(f'could not broadcast input array from shape {value_shape} to {data_shape}')
|
|
240
|
-
return source_shape
|
|
225
|
+
if len(value_shape) > len(data_shape):
|
|
226
|
+
return data_shape
|
|
227
|
+
return value_shape
|
|
241
228
|
|
|
242
229
|
|
|
243
230
|
@constexpr
|
|
@@ -407,10 +394,21 @@ def slice2indices(input_slice, shape):
|
|
|
407
394
|
start, stop, step = normalize_slice(input_slice, shape[0])
|
|
408
395
|
if check_slice_empty(start, stop, step):
|
|
409
396
|
return False
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
397
|
+
ndim = len(shape)
|
|
398
|
+
mesh = list()
|
|
399
|
+
grids = [P.Range()(P.Fill()(mstype.int64, (), start), P.Fill()(
|
|
400
|
+
mstype.int64, (), stop), P.Fill()(mstype.int64, (), step))]
|
|
401
|
+
grids += [P.Range()(Tensor(0, mstype.int64), P.Fill()(mstype.int64, (), dim_size),
|
|
402
|
+
Tensor(1, mstype.int64)) for dim_size in shape[1:]]
|
|
403
|
+
for j, grid in enumerate(grids):
|
|
404
|
+
mesh.append(P.Reshape()(grid, tuple(
|
|
405
|
+
[grid.size if j == t else 1 for t in range(ndim)])))
|
|
406
|
+
shapes = map(P.Shape(), mesh)
|
|
407
|
+
out_shape = infer_out_shape(*shapes)
|
|
408
|
+
mesh_arrays = list()
|
|
409
|
+
for arr in mesh:
|
|
410
|
+
mesh_arrays.append(P.BroadcastTo(out_shape)(arr))
|
|
411
|
+
return P.Stack(-1)(mesh_arrays)
|
|
414
412
|
|
|
415
413
|
|
|
416
414
|
@constexpr
|
|
@@ -422,7 +420,7 @@ def check_indices(indices_size, index):
|
|
|
422
420
|
return indices_size
|
|
423
421
|
|
|
424
422
|
|
|
425
|
-
@
|
|
423
|
+
@_primexpr
|
|
426
424
|
def check_indices_value_size(indices_size, value_size):
|
|
427
425
|
"""Checks if the sizes are already matched."""
|
|
428
426
|
if value_size < 1:
|
|
@@ -479,35 +477,61 @@ def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
|
|
|
479
477
|
f"is not consistent with assigned tensor data type {data_dtype}.")
|
|
480
478
|
|
|
481
479
|
|
|
480
|
+
@constexpr
|
|
481
|
+
def get_broadcast_shape(x_shape, y_shape, prim_name):
|
|
482
|
+
"""Get broadcast shape from input shapes."""
|
|
483
|
+
if x_shape is None or y_shape is None:
|
|
484
|
+
raise ValueError("get_broadcast_shape has dynamic rank input")
|
|
485
|
+
if None in x_shape or None in y_shape:
|
|
486
|
+
raise ValueError("get_broadcast_shape has dynamic shape input")
|
|
487
|
+
if x_shape == y_shape:
|
|
488
|
+
return x_shape
|
|
489
|
+
x_len = len(x_shape)
|
|
490
|
+
y_len = len(y_shape)
|
|
491
|
+
length = x_len if x_len < y_len else y_len
|
|
492
|
+
broadcast_shape_back = []
|
|
493
|
+
|
|
494
|
+
for i in range(-length, 0):
|
|
495
|
+
if x_shape[i] == 1:
|
|
496
|
+
broadcast_shape_back.append(y_shape[i])
|
|
497
|
+
elif y_shape[i] == 1:
|
|
498
|
+
broadcast_shape_back.append(x_shape[i])
|
|
499
|
+
elif x_shape[i] == y_shape[i]:
|
|
500
|
+
broadcast_shape_back.append(x_shape[i])
|
|
501
|
+
else:
|
|
502
|
+
raise ValueError(f"For '{prim_name}', x.shape and y.shape need to "
|
|
503
|
+
f"broadcast. The value of x.shape[{i}] or y.shape[{i}]"
|
|
504
|
+
f" must be 1 or -1 when they are not the same, "
|
|
505
|
+
f"but got x.shape = {x_shape} "
|
|
506
|
+
f"and y.shape = {y_shape}.")
|
|
507
|
+
|
|
508
|
+
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
|
509
|
+
broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
|
|
510
|
+
return broadcast_shape
|
|
511
|
+
|
|
512
|
+
|
|
482
513
|
@constexpr
|
|
483
514
|
def generate_broadcast_shape(shapes, op_name):
|
|
484
515
|
"""Generate broadcast shape for a tuple of shape."""
|
|
485
516
|
if not shapes:
|
|
486
517
|
return ()
|
|
487
518
|
broadcast_shape = shapes[0]
|
|
488
|
-
for
|
|
489
|
-
|
|
490
|
-
try:
|
|
491
|
-
broadcast_shape = op_utils.get_broadcast_shape(
|
|
492
|
-
broadcast_shape, shape, op_name)
|
|
493
|
-
except ValueError as ex:
|
|
494
|
-
raise IndexError(ex)
|
|
519
|
+
for shape in shapes:
|
|
520
|
+
broadcast_shape = get_broadcast_shape(tuple(broadcast_shape), shape, op_name)
|
|
495
521
|
return tuple(broadcast_shape)
|
|
496
522
|
|
|
497
523
|
|
|
498
|
-
@
|
|
524
|
+
@_primexpr
|
|
499
525
|
def check_two_shapes_need_broadcast(shape_x, shape_y):
|
|
500
526
|
"""Check shape_y needs to be broadcast to shape_x."""
|
|
501
|
-
if any(j not in (i, 1) for i, j in zip(reversed(shape_x), reversed(shape_y))):
|
|
502
|
-
raise ValueError(f"{shape_y} could not broadcast with {shape_x}.")
|
|
503
527
|
return shape_y != shape_x
|
|
504
528
|
|
|
505
529
|
|
|
506
|
-
@
|
|
530
|
+
@_primexpr
|
|
507
531
|
def compute_multiples(origin_shape, broadcast_shape):
|
|
508
532
|
"""Compute multiples between origin shape with broadcast shape."""
|
|
509
533
|
len_gap = len(broadcast_shape) - len(origin_shape)
|
|
510
|
-
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
|
|
534
|
+
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], tuple(origin_shape)))
|
|
511
535
|
|
|
512
536
|
|
|
513
537
|
@constexpr
|
|
@@ -517,7 +541,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
|
|
|
517
541
|
updates_shape = indices_shape + data_shape[1:]
|
|
518
542
|
else:
|
|
519
543
|
updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
|
|
520
|
-
return
|
|
544
|
+
return P.Fill()(data_dtype, updates_shape, value)
|
|
521
545
|
|
|
522
546
|
|
|
523
547
|
def generate_updates_shape(data_shape, index_shape, op_type, is_dynamic):
|
|
@@ -578,12 +602,6 @@ def _judge_order_continuous(order_sequence):
|
|
|
578
602
|
@constexpr
|
|
579
603
|
def scalar_in_sequence(x, y):
|
|
580
604
|
"""Determine whether the scalar in the sequence."""
|
|
581
|
-
if x is None:
|
|
582
|
-
raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, "
|
|
583
|
-
"but the scalar is not.")
|
|
584
|
-
if y is None:
|
|
585
|
-
raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, "
|
|
586
|
-
"but the sequence is not.")
|
|
587
605
|
return x in y
|
|
588
606
|
|
|
589
607
|
|
|
@@ -606,7 +624,7 @@ def check_number_index_type(number):
|
|
|
606
624
|
.format(number, type(number)))
|
|
607
625
|
|
|
608
626
|
|
|
609
|
-
@
|
|
627
|
+
@_primexpr
|
|
610
628
|
def get_stride_info_from_slice(data_shape, slice_index):
|
|
611
629
|
"""Get stride info from a python slice"""
|
|
612
630
|
begin, end, step = get_slice_stride(slice_index, data_shape[0])
|
|
@@ -707,7 +725,7 @@ def unpack(x):
|
|
|
707
725
|
return x
|
|
708
726
|
|
|
709
727
|
|
|
710
|
-
@
|
|
728
|
+
@_primexpr
|
|
711
729
|
def normalize_start(start, dim_size):
|
|
712
730
|
"""
|
|
713
731
|
Normalize `start` according to the number of dimensions (`dim_size`).
|
|
@@ -715,24 +733,24 @@ def normalize_start(start, dim_size):
|
|
|
715
733
|
"""
|
|
716
734
|
if start is None:
|
|
717
735
|
return 0
|
|
718
|
-
if dim_size
|
|
736
|
+
if dim_size is None:
|
|
719
737
|
return start
|
|
720
738
|
if start < 0:
|
|
721
739
|
return 0 if start < -dim_size else start % dim_size
|
|
722
740
|
return start if start < dim_size else dim_size
|
|
723
741
|
|
|
724
742
|
|
|
725
|
-
@
|
|
743
|
+
@_primexpr
|
|
726
744
|
def normalize_stop(stop, dim_size):
|
|
727
745
|
"""
|
|
728
746
|
Normalize `stop` according to the number of dimensions (`dim_size`).
|
|
729
747
|
If the number of dimensions is not given, return the original input directly.
|
|
730
748
|
"""
|
|
731
|
-
if stop is None and dim_size
|
|
749
|
+
if stop is None and dim_size is None:
|
|
732
750
|
raise IndexError("Not Support stop is None when dim is dynamic")
|
|
733
751
|
if stop is None:
|
|
734
752
|
return dim_size
|
|
735
|
-
if dim_size
|
|
753
|
+
if dim_size is None:
|
|
736
754
|
return stop
|
|
737
755
|
if stop < 0:
|
|
738
756
|
return 0 if stop < -dim_size else stop % dim_size
|
|
@@ -748,11 +766,9 @@ def get_step_from_slice(input_slice):
|
|
|
748
766
|
return step
|
|
749
767
|
|
|
750
768
|
|
|
751
|
-
@
|
|
769
|
+
@_primexpr
|
|
752
770
|
def normalize_slice(input_slice, dim_size):
|
|
753
771
|
"""Normalizes start, stop, step in a slice."""
|
|
754
|
-
start = normalize_start(input_slice.start, dim_size)
|
|
755
|
-
stop = normalize_stop(input_slice.stop, dim_size)
|
|
756
772
|
step = input_slice.step
|
|
757
773
|
if step is None:
|
|
758
774
|
step = 1
|
|
@@ -771,7 +787,6 @@ def tuple_slice(tup, start, end):
|
|
|
771
787
|
return tup[start:end]
|
|
772
788
|
|
|
773
789
|
|
|
774
|
-
@constexpr
|
|
775
790
|
def expanded_shape(shape, expand_size):
|
|
776
791
|
return (1,)*expand_size + shape
|
|
777
792
|
|
|
@@ -804,12 +819,20 @@ def is_slice(x):
|
|
|
804
819
|
return isinstance(x, slice)
|
|
805
820
|
|
|
806
821
|
|
|
807
|
-
@
|
|
822
|
+
@_primexpr
|
|
808
823
|
def filter_expanded_dims(shape, not_expanded_dim):
|
|
824
|
+
"""filter_expanded_dims"""
|
|
825
|
+
def _check(diff, shape):
|
|
826
|
+
if diff < 0:
|
|
827
|
+
raise ValueError(f'unable to broadcast {shape}')
|
|
828
|
+
|
|
809
829
|
diff = len(not_expanded_dim) - len(shape)
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
830
|
+
_check(diff, shape)
|
|
831
|
+
res = list()
|
|
832
|
+
for i, flag in zip(shape, not_expanded_dim[diff:]):
|
|
833
|
+
if flag:
|
|
834
|
+
res.append(i)
|
|
835
|
+
return tuple(res)
|
|
813
836
|
|
|
814
837
|
|
|
815
838
|
@constexpr
|
|
@@ -818,7 +841,7 @@ def sequence_to_index(sequence, dim_size):
|
|
|
818
841
|
if not sequence:
|
|
819
842
|
return False
|
|
820
843
|
if all(isinstance(i, bool) for i in sequence):
|
|
821
|
-
if dim_size
|
|
844
|
+
if dim_size is None:
|
|
822
845
|
raise IndexError("Not supported to take the subscript of dynamic shape tensor using Boolean type")
|
|
823
846
|
seq_size = len(sequence)
|
|
824
847
|
if seq_size != dim_size:
|
|
@@ -829,19 +852,30 @@ def sequence_to_index(sequence, dim_size):
|
|
|
829
852
|
return make_tensor(sequence, mstype.int64, None, dim_size)
|
|
830
853
|
|
|
831
854
|
|
|
832
|
-
@
|
|
855
|
+
@_primexpr
|
|
833
856
|
def int_to_index(i, shape):
|
|
834
857
|
"""Converts integer to tensor indices."""
|
|
858
|
+
def _check(i, dim_size):
|
|
859
|
+
if i < -dim_size or i >= dim_size:
|
|
860
|
+
raise IndexError(f'index {i} is out of bounds for axis 0 with size {dim_size}')
|
|
861
|
+
|
|
835
862
|
dim_size = shape[0]
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
i = i % dim_size
|
|
863
|
+
_check(i, dim_size)
|
|
864
|
+
i = (i + dim_size) % dim_size
|
|
839
865
|
if len(shape) == 1:
|
|
840
|
-
return
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
866
|
+
return P.Fill()(mstype.int64, (1, 1), i)
|
|
867
|
+
mesh = list()
|
|
868
|
+
ndim = len(shape) - 1
|
|
869
|
+
for j, size in enumerate(shape[1:]):
|
|
870
|
+
grid = P.Range()(Tensor(0, mstype.int64), P.Fill()(mstype.int64, (), size), Tensor(1, mstype.int64))
|
|
871
|
+
mesh.append(P.Reshape()(grid, tuple([size if j == t else 1 for t in range(ndim)])))
|
|
872
|
+
shapes = map(P.Shape(), mesh)
|
|
873
|
+
out_shape = infer_out_shape(*shapes)
|
|
874
|
+
mesh_arrays = list()
|
|
875
|
+
for arr in mesh:
|
|
876
|
+
mesh_arrays.append(P.BroadcastTo(out_shape)(arr))
|
|
877
|
+
index = P.Stack(-1)(mesh_arrays)
|
|
878
|
+
return P.Concat(-1)((P.Fill()(mstype.int64, P.Shape()(index)[:-1] + (1,), i), index))
|
|
845
879
|
|
|
846
880
|
|
|
847
881
|
@constexpr
|
|
@@ -863,12 +897,12 @@ def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim
|
|
|
863
897
|
return not_expanded_dim, idx_advanced
|
|
864
898
|
|
|
865
899
|
|
|
866
|
-
@
|
|
900
|
+
@_primexpr
|
|
867
901
|
def check_slice_empty(start, stop, step):
|
|
868
902
|
return (start - stop)*step >= 0
|
|
869
903
|
|
|
870
904
|
|
|
871
|
-
@
|
|
905
|
+
@_primexpr
|
|
872
906
|
def real_axes(ndim_orig, ndim_out, axes_orig):
|
|
873
907
|
"""Returns the real axes to be reduced after performing broadcast"""
|
|
874
908
|
_diff = ndim_out - ndim_orig
|
|
@@ -880,7 +914,7 @@ def real_axes(ndim_orig, ndim_out, axes_orig):
|
|
|
880
914
|
check_axis_valid_const = constexpr(validator.check_axis_valid)
|
|
881
915
|
|
|
882
916
|
|
|
883
|
-
@
|
|
917
|
+
@_primexpr
|
|
884
918
|
def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_position):
|
|
885
919
|
"""Computes slice tensor shapes"""
|
|
886
920
|
shape = [1] * len(slice_shape)
|
|
@@ -889,18 +923,19 @@ def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_posit
|
|
|
889
923
|
return shape
|
|
890
924
|
|
|
891
925
|
|
|
892
|
-
@
|
|
926
|
+
@_primexpr
|
|
893
927
|
def infer_out_shape(*shapes):
|
|
894
928
|
"""
|
|
895
929
|
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
|
|
896
930
|
"""
|
|
897
|
-
shape_out =
|
|
898
|
-
|
|
899
|
-
|
|
931
|
+
shape_out = list()
|
|
932
|
+
max_len = max([len(it) for it in shapes])
|
|
933
|
+
|
|
934
|
+
for i in range(max_len):
|
|
935
|
+
items = [it[i-max_len+len(it)] if i-max_len +
|
|
936
|
+
len(it) >= 0 else 1 for it in shapes]
|
|
900
937
|
max_size = 0 if 0 in items else max(items)
|
|
901
|
-
|
|
902
|
-
raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
|
|
903
|
-
shape_out.appendleft(max_size)
|
|
938
|
+
shape_out.append(max_size)
|
|
904
939
|
return tuple(shape_out)
|
|
905
940
|
|
|
906
941
|
|
|
@@ -913,6 +948,12 @@ def use_copy_slice(tuple_index):
|
|
|
913
948
|
return False
|
|
914
949
|
|
|
915
950
|
|
|
951
|
+
@constexpr
|
|
952
|
+
def is_ascend():
|
|
953
|
+
"""Device target is Ascend or not"""
|
|
954
|
+
return context.get_context('device_target') == "Ascend"
|
|
955
|
+
|
|
956
|
+
|
|
916
957
|
@constexpr
|
|
917
958
|
def gen_exception_msg(msg_format, *args):
|
|
918
959
|
return msg_format.format(*args)
|
|
@@ -938,8 +979,22 @@ def get_output_dtype(dtype_1, dtype_2, use_complex=False):
|
|
|
938
979
|
|
|
939
980
|
@constexpr
|
|
940
981
|
def promote_binary_dtype(dtype_1, dtype_2):
|
|
982
|
+
"""
|
|
983
|
+
promote binary types
|
|
984
|
+
"""
|
|
941
985
|
if dtype_1 == dtype_2:
|
|
942
986
|
return dtype_1
|
|
943
987
|
if dtype_1 in complex_types or dtype_2 in complex_types:
|
|
944
988
|
return get_output_dtype(dtype_1, dtype_2, True)
|
|
945
989
|
return get_output_dtype(dtype_1, dtype_2, False)
|
|
990
|
+
|
|
991
|
+
|
|
992
|
+
@_primexpr
|
|
993
|
+
def generate_padding_shape(shape, length):
|
|
994
|
+
"""
|
|
995
|
+
pad the `shape` to `length` with 1.
|
|
996
|
+
"""
|
|
997
|
+
if len(shape) > length:
|
|
998
|
+
raise ValueError(f"Can not pad {shape} to length {length}.")
|
|
999
|
+
|
|
1000
|
+
return shape + (1,) * (length - len(shape))
|
|
@@ -21,6 +21,7 @@ from mindspore.ops.composite import base
|
|
|
21
21
|
from mindspore.ops import functional as F
|
|
22
22
|
from mindspore.ops.composite.multitype_ops._constexpr_utils import make_tensor, check_equal
|
|
23
23
|
from mindspore.common import CSRTensor, COOTensor
|
|
24
|
+
from ...operations._sequence_ops import SequenceAdd
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
add = base.MultitypeFuncGraph('add', True)
|
|
@@ -195,6 +196,8 @@ def _list_add_list(x, y):
|
|
|
195
196
|
Returns:
|
|
196
197
|
list, has the same dtype as x.
|
|
197
198
|
"""
|
|
199
|
+
if F.is_sequence_shape_unknown(x) or F.is_sequence_shape_unknown(y):
|
|
200
|
+
return SequenceAdd()(x, y)
|
|
198
201
|
for i in y:
|
|
199
202
|
x.append(i)
|
|
200
203
|
return x
|
|
@@ -272,6 +275,8 @@ def _add_tuple(x, y):
|
|
|
272
275
|
Returns:
|
|
273
276
|
Tuple, consists of elements of x and elements of y.
|
|
274
277
|
"""
|
|
278
|
+
if F.is_sequence_shape_unknown(x) or F.is_sequence_shape_unknown(y):
|
|
279
|
+
return SequenceAdd()(x, y)
|
|
275
280
|
return _tuple_add(x, y)
|
|
276
281
|
|
|
277
282
|
|
|
@@ -304,7 +309,7 @@ def _add_cootensor(x, y):
|
|
|
304
309
|
COOTensor, consists of elements of x and elements of y.
|
|
305
310
|
"""
|
|
306
311
|
check_equal(x.shape, y.shape, "input1 (shape={}) and input2(shape={}) should be the same shape.")
|
|
307
|
-
return F.
|
|
312
|
+
return F.coo_add(x, y, make_tensor(0, x.values.dtype))
|
|
308
313
|
|
|
309
314
|
|
|
310
315
|
@add.register("COOTensor", "Tensor")
|
|
@@ -20,7 +20,7 @@ from mindspore.ops.composite.multitype_ops import _compile_utils as utils
|
|
|
20
20
|
from mindspore.ops.composite.multitype_ops._constexpr_utils import log_warning, check_equal
|
|
21
21
|
from mindspore.ops.composite import base
|
|
22
22
|
from mindspore.ops import functional as F
|
|
23
|
-
from mindspore.common import
|
|
23
|
+
from mindspore.common import COOTensor
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
div = base.MultitypeFuncGraph("div", True)
|
|
@@ -39,8 +39,7 @@ def _csrtensor_div_tensor(x, y):
|
|
|
39
39
|
CSRTensor, equal to x / y.
|
|
40
40
|
"""
|
|
41
41
|
log_warning("For CSR divide, zero values in the dense tensor are ignored.")
|
|
42
|
-
|
|
43
|
-
return CSRTensor(x.indptr, x.indices, data, x.shape)
|
|
42
|
+
return F.csr_div(x, y)
|
|
44
43
|
|
|
45
44
|
|
|
46
45
|
@div.register("COOTensor", "Tensor")
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -20,13 +20,20 @@ from mindspore.ops.operations import _map_tensor_ops
|
|
|
20
20
|
from mindspore.ops.composite.multitype_ops import _compile_utils as compile_utils
|
|
21
21
|
from mindspore.ops.composite import base
|
|
22
22
|
from mindspore.ops import functional as F
|
|
23
|
+
from mindspore.ops.operations._inner_ops import SliceGetItem
|
|
24
|
+
from ...operations._sequence_ops import SequenceSlice
|
|
23
25
|
|
|
24
|
-
|
|
26
|
+
DOC_URL = "https://mindspore.cn/docs/zh-CN/master/note/index_support.html"
|
|
27
|
+
|
|
28
|
+
getitem = base.MultitypeFuncGraph('getitem', doc_url=DOC_URL)
|
|
25
29
|
"""
|
|
26
30
|
getitem is a metafuncgraph object which will get item from an object according to input type
|
|
27
31
|
using ".register" decorator.
|
|
28
32
|
"""
|
|
29
33
|
|
|
34
|
+
slice_getitem = SliceGetItem()
|
|
35
|
+
sequence_slice = SequenceSlice()
|
|
36
|
+
|
|
30
37
|
|
|
31
38
|
class _TupleSlice(base.SequenceSliceGetItem_):
|
|
32
39
|
"""
|
|
@@ -126,6 +133,17 @@ def _tuple_getitem_by_slice(data, slice_index):
|
|
|
126
133
|
Outputs:
|
|
127
134
|
Tuple, element type is the same as the element type of data.
|
|
128
135
|
"""
|
|
136
|
+
if F.is_sequence_shape_unknown(data) or not F.isconstant(slice_index):
|
|
137
|
+
start = slice_getitem(slice_index, "start")
|
|
138
|
+
stop = slice_getitem(slice_index, "stop")
|
|
139
|
+
step = slice_getitem(slice_index, "step")
|
|
140
|
+
if start is None:
|
|
141
|
+
start = 0
|
|
142
|
+
if step is None:
|
|
143
|
+
step = 1
|
|
144
|
+
if stop is None:
|
|
145
|
+
stop = (2**31-1) if step >= 1 else -(2**31-1)
|
|
146
|
+
return sequence_slice(data, start, stop, step)
|
|
129
147
|
return _tuple_slice(data, slice_index)
|
|
130
148
|
|
|
131
149
|
|
|
@@ -141,7 +159,6 @@ def _tuple_getitem_by_tensor(data, tensor_index):
|
|
|
141
159
|
Outputs:
|
|
142
160
|
Type, is the same as the element type of data.
|
|
143
161
|
"""
|
|
144
|
-
tensor_index = F.select(tensor_index >= 0, tensor_index, tensor_index + len(data))
|
|
145
162
|
return _tuple_get_item_tensor(data, tensor_index)
|
|
146
163
|
|
|
147
164
|
|
|
@@ -172,6 +189,17 @@ def _list_getitem_by_slice(data, slice_index):
|
|
|
172
189
|
Outputs:
|
|
173
190
|
List, element type is the same as the element type of data.
|
|
174
191
|
"""
|
|
192
|
+
if F.is_sequence_shape_unknown(data) or not F.isconstant(slice_index):
|
|
193
|
+
start = slice_getitem(slice_index, "start")
|
|
194
|
+
stop = slice_getitem(slice_index, "stop")
|
|
195
|
+
step = slice_getitem(slice_index, "step")
|
|
196
|
+
if start is None:
|
|
197
|
+
start = 0
|
|
198
|
+
if step is None:
|
|
199
|
+
step = 1
|
|
200
|
+
if stop is None:
|
|
201
|
+
stop = (2**31-1) if step >= 1 else -(2**31-1)
|
|
202
|
+
return sequence_slice(data, start, stop, step)
|
|
175
203
|
return _list_slice(data, slice_index)
|
|
176
204
|
|
|
177
205
|
|
|
@@ -17,6 +17,7 @@
|
|
|
17
17
|
from mindspore.ops.composite import base
|
|
18
18
|
from mindspore.ops import functional as F
|
|
19
19
|
from mindspore.ops.operations import _inner_ops as inner
|
|
20
|
+
from ...operations._sequence_ops import tuple_greater_equal, list_greater_equal
|
|
20
21
|
|
|
21
22
|
# greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type
|
|
22
23
|
# using ".register" decorator
|
|
@@ -68,3 +69,33 @@ def _greater_equal_tensor(x, y):
|
|
|
68
69
|
Tensor, return value by operator P.GreaterEqual.
|
|
69
70
|
"""
|
|
70
71
|
return F.tensor_ge(x, y)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@greater_equal.register("Tuple", "Tuple")
|
|
75
|
+
def _greater_equal_tuple(x, y):
|
|
76
|
+
"""
|
|
77
|
+
Determine whether x is greater than or equal to y.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
x(Tuple): Tuple.
|
|
81
|
+
y(Tuple): Tuple.
|
|
82
|
+
|
|
83
|
+
Returns:
|
|
84
|
+
bool, if x >= y return true in python logic, x < y return false.
|
|
85
|
+
"""
|
|
86
|
+
return tuple_greater_equal()(x, y)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@greater_equal.register("List", "List")
|
|
90
|
+
def _greater_equal_list(x, y):
|
|
91
|
+
"""
|
|
92
|
+
Determine whether x is greater than or equal to y.
|
|
93
|
+
|
|
94
|
+
Args:
|
|
95
|
+
x(List): List.
|
|
96
|
+
y(List): List.
|
|
97
|
+
|
|
98
|
+
Returns:
|
|
99
|
+
bool, if x >= y return true in python logic, x < y return false.
|
|
100
|
+
"""
|
|
101
|
+
return list_greater_equal()(x, y)
|
|
@@ -18,6 +18,7 @@
|
|
|
18
18
|
from mindspore.ops.composite import base
|
|
19
19
|
from mindspore.ops import functional as F
|
|
20
20
|
from mindspore.ops.operations import _inner_ops as inner
|
|
21
|
+
from ...operations._sequence_ops import tuple_greater_than, list_greater_than
|
|
21
22
|
|
|
22
23
|
# greater is a metafuncgraph object which will determine if two objects are greater according to input type
|
|
23
24
|
# using ".register" decorator
|
|
@@ -69,3 +70,33 @@ def _greater_tensor(x, y):
|
|
|
69
70
|
tensor, return operation of x and y by P.Greater.
|
|
70
71
|
"""
|
|
71
72
|
return F.tensor_gt(x, y)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@greater.register("Tuple", "Tuple")
|
|
76
|
+
def _greater_than_tuple(x, y):
|
|
77
|
+
"""
|
|
78
|
+
Determine whether x is greater than y.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
x(Tuple): Tuple.
|
|
82
|
+
y(Tuple): Tuple.
|
|
83
|
+
|
|
84
|
+
Returns:
|
|
85
|
+
bool, if x > y return true in python logic, x <= y return false.
|
|
86
|
+
"""
|
|
87
|
+
return tuple_greater_than()(x, y)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@greater.register("List", "List")
|
|
91
|
+
def _greater_than_list(x, y):
|
|
92
|
+
"""
|
|
93
|
+
Determine whether x is greater than y.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
x(List): List.
|
|
97
|
+
y(List): List.
|
|
98
|
+
|
|
99
|
+
Returns:
|
|
100
|
+
bool, if x > y return true in python logic, x <= y return false.
|
|
101
|
+
"""
|
|
102
|
+
return list_greater_than()(x, y)
|