mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -13,22 +13,22 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""math Operations."""
|
|
16
|
-
import numpy as np
|
|
17
16
|
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
18
17
|
from mindspore.common import dtype as mstype
|
|
19
|
-
from mindspore
|
|
20
|
-
from mindspore.ops.primitive import constexpr
|
|
18
|
+
from mindspore import _checkparam as validator
|
|
19
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
21
20
|
from mindspore.ops import functional as F
|
|
22
|
-
from mindspore.ops.operations._inner_ops import DynamicResizeNearestNeighbor
|
|
23
21
|
from mindspore.ops.function.math_func import cummin as cummin_
|
|
24
22
|
from mindspore.ops import operations as P
|
|
25
23
|
|
|
26
24
|
|
|
27
|
-
@
|
|
25
|
+
@_primexpr
|
|
28
26
|
def _check_validate_axis(axis, name):
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
27
|
+
def _check(axis):
|
|
28
|
+
if isinstance(axis, (tuple, list)):
|
|
29
|
+
for idx, item in enumerate(axis):
|
|
30
|
+
validator.check_value_type("axis[%d]" % idx, item, [int], name)
|
|
31
|
+
_check(axis)
|
|
32
32
|
axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
|
|
33
33
|
return axis
|
|
34
34
|
|
|
@@ -46,24 +46,26 @@ def is_const(x):
|
|
|
46
46
|
|
|
47
47
|
def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
|
48
48
|
r"""
|
|
49
|
-
Count number of nonzero elements across axis of input tensor
|
|
49
|
+
Count number of nonzero elements across axis of input tensor.
|
|
50
50
|
|
|
51
51
|
Args:
|
|
52
|
-
x (Tensor): Input data is used to count non-zero numbers.
|
|
53
|
-
|
|
54
|
-
axis (Union[int, tuple(int), list(int)]): The dimensions to reduce.
|
|
55
|
-
|
|
56
|
-
keep_dims (bool):
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
52
|
+
x (Tensor): Input data is used to count non-zero numbers. With shape
|
|
53
|
+
:math:`(N,*)` where :math:`*` means, any number of additional dimensions.
|
|
54
|
+
axis (Union[int, tuple(int), list(int)], optional): The dimensions to reduce.
|
|
55
|
+
Default: (), reduce all dimensions.
|
|
56
|
+
keep_dims (bool, optional): Whether to maintain dimensions specified by `axis`.
|
|
57
|
+
If true, keep these reduced dimensions and the length is 1.
|
|
58
|
+
If false, don't keep these dimensions. Default: False.
|
|
59
|
+
dtype (Union[Number, mindspore.bool\_], optional): The data type of the output tensor.
|
|
60
|
+
Default: mindspore.int32.
|
|
60
61
|
|
|
61
62
|
Returns:
|
|
62
|
-
Tensor, number of nonzero element
|
|
63
|
+
Tensor, number of nonzero element across axis specified by `axis`.
|
|
64
|
+
The data type is specified by `dtype`.
|
|
63
65
|
|
|
64
66
|
Raises:
|
|
65
|
-
TypeError: If axis is not int or
|
|
66
|
-
ValueError: If axis is not in range [-x.ndim, x.ndim).
|
|
67
|
+
TypeError: If `axis` is not int, tuple or list.
|
|
68
|
+
ValueError: If any value in `axis` is not in range [-x.ndim, x.ndim).
|
|
67
69
|
|
|
68
70
|
Supported Platforms:
|
|
69
71
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -116,7 +118,7 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
|
|
116
118
|
return nonzero_num
|
|
117
119
|
|
|
118
120
|
|
|
119
|
-
@
|
|
121
|
+
@_primexpr
|
|
120
122
|
def _int_to_tuple_conv(axes):
|
|
121
123
|
"""
|
|
122
124
|
Converts ints to tuples in input axes, expected by most validation checks.
|
|
@@ -127,7 +129,7 @@ def _int_to_tuple_conv(axes):
|
|
|
127
129
|
return axes
|
|
128
130
|
|
|
129
131
|
|
|
130
|
-
@
|
|
132
|
+
@_primexpr
|
|
131
133
|
def _check_axes(axes, prim_name=None):
|
|
132
134
|
"""
|
|
133
135
|
Check for validity and type of axes passed to function.
|
|
@@ -160,21 +162,29 @@ def _typecheck_input(x1_type, x2_type, prim_name=None):
|
|
|
160
162
|
f"and x2_type: {x2_type}.")
|
|
161
163
|
|
|
162
164
|
|
|
163
|
-
@
|
|
165
|
+
@_primexpr
|
|
164
166
|
def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
|
|
165
167
|
"""
|
|
166
168
|
Convert from single int axes to 2d tuple if required
|
|
167
169
|
"""
|
|
168
170
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
169
|
-
|
|
171
|
+
|
|
172
|
+
def _check_lt_zero(axes):
|
|
170
173
|
if axes < 0:
|
|
171
174
|
raise ValueError(f"{msg_prefix} 'axes' must be at least 0, but got {axes}.")
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
return [], []
|
|
175
|
+
|
|
176
|
+
def _check_len(axes, x1_shape, x2_shape):
|
|
175
177
|
if axes > len(x1_shape) or axes > len(x2_shape):
|
|
176
178
|
raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
|
|
177
179
|
f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
if isinstance(axes, int):
|
|
183
|
+
_check_lt_zero(axes)
|
|
184
|
+
if axes == 0:
|
|
185
|
+
# outer product, no input validation required
|
|
186
|
+
return [], []
|
|
187
|
+
_check_len(axes, x1_shape, x2_shape)
|
|
178
188
|
x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
|
|
179
189
|
x2_ind = tuple(range(len(x2_shape))[:axes])
|
|
180
190
|
axes = tuple((x1_ind, x2_ind))
|
|
@@ -182,7 +192,7 @@ def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
182
192
|
return axes
|
|
183
193
|
|
|
184
194
|
|
|
185
|
-
@
|
|
195
|
+
@_primexpr
|
|
186
196
|
def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
|
|
187
197
|
"""
|
|
188
198
|
Checks for axes having the correct length according to input, for any value in axis
|
|
@@ -190,25 +200,32 @@ def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
190
200
|
with given inputs.
|
|
191
201
|
"""
|
|
192
202
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
203
|
+
|
|
204
|
+
def _check_len(axes_len, shape_dim_len, x_axes):
|
|
205
|
+
if axes_len > shape_dim_len:
|
|
206
|
+
raise ValueError(f"{msg_prefix} length of element {x_axes} in 'axes' must be less than or equal to "
|
|
207
|
+
f"{shape_dim_len}, but got {axes_len}.")
|
|
208
|
+
|
|
209
|
+
def _check_value(x_axes, min_val, max_val):
|
|
210
|
+
for _, x_value in enumerate(x_axes):
|
|
211
|
+
if x_value > max_val or x_value < min_val:
|
|
212
|
+
raise ValueError(f"{msg_prefix} value in 'axes' must be in range: [{min_val}, {max_val}], "
|
|
213
|
+
f"but got {x_value}.")
|
|
214
|
+
|
|
193
215
|
shapes = [x1_shape, x2_shape]
|
|
194
216
|
|
|
195
217
|
# axis length check
|
|
196
218
|
for ix_input, x_axes in enumerate(axes):
|
|
197
219
|
axes_len = len(x_axes)
|
|
198
220
|
shape_dim_len = len(shapes[ix_input])
|
|
199
|
-
|
|
200
|
-
raise ValueError(f"{msg_prefix} length of element {x_axes} in 'axes' must be less than or equal to "
|
|
201
|
-
f"{shape_dim_len}, but got {axes_len}.")
|
|
221
|
+
_check_len(axes_len, shape_dim_len, x_axes)
|
|
202
222
|
|
|
203
223
|
# axis values range check
|
|
204
224
|
for ix_input, x_axes in enumerate(axes):
|
|
205
225
|
comp_shape = shapes[ix_input]
|
|
206
226
|
max_val = len(comp_shape) - 1
|
|
207
227
|
min_val = -1 * len(comp_shape)
|
|
208
|
-
|
|
209
|
-
if not min_val <= x_value <= max_val:
|
|
210
|
-
raise ValueError(f"{msg_prefix} value in 'axes' must be in range: [{min_val}, {max_val}], "
|
|
211
|
-
f"but got {x_value}.")
|
|
228
|
+
_check_value(x_axes, min_val, max_val)
|
|
212
229
|
|
|
213
230
|
# check axis value with input shape - both ways for axis valid
|
|
214
231
|
invalid_a = False
|
|
@@ -218,23 +235,31 @@ def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
218
235
|
invalid_a = True
|
|
219
236
|
if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0]) - 1 - i]]:
|
|
220
237
|
invalid_b = True
|
|
221
|
-
if invalid_a and invalid_b:
|
|
222
|
-
raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
|
|
223
|
-
f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
|
|
224
|
-
f"'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}, 'axes': {axes}.")
|
|
225
238
|
|
|
239
|
+
def _check(invalid_a, invalid_b, x1_shape, x2_shape, axes):
|
|
240
|
+
if invalid_a and invalid_b:
|
|
241
|
+
raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
|
|
242
|
+
f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
|
|
243
|
+
f"'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}, 'axes': {axes}.")
|
|
226
244
|
|
|
227
|
-
|
|
245
|
+
_check(invalid_a, invalid_b, x1_shape, x2_shape, axes)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@_primexpr
|
|
228
249
|
def _calc_new_shape(shape, axes, position=0):
|
|
229
250
|
"""
|
|
230
251
|
Calculate transpose and reshape parameters for input transformations,
|
|
231
252
|
'position' refers to whether tensor is first or second in the op.
|
|
232
253
|
"""
|
|
233
254
|
contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
|
|
234
|
-
prod_contraction =
|
|
255
|
+
prod_contraction = 1
|
|
256
|
+
for i in contraction_axes:
|
|
257
|
+
prod_contraction *= shape[i]
|
|
235
258
|
free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
|
|
236
|
-
free_dims = tuple(shape[i] for i in free_axes)
|
|
237
|
-
prod_free =
|
|
259
|
+
free_dims = tuple(shape[i] if shape[i] is not None else -1 for i in free_axes)
|
|
260
|
+
prod_free = 1
|
|
261
|
+
for free_dim in free_dims:
|
|
262
|
+
prod_free *= free_dim
|
|
238
263
|
|
|
239
264
|
transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
|
|
240
265
|
new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
|
|
@@ -294,10 +319,7 @@ def tensor_dot(x1, x2, axes):
|
|
|
294
319
|
# input validity checks
|
|
295
320
|
x1_shape = shape_op(x1)
|
|
296
321
|
x2_shape = shape_op(x2)
|
|
297
|
-
x1_type = F.dtype(x1)
|
|
298
|
-
x2_type = F.dtype(x2)
|
|
299
322
|
axes = _check_axes(axes, 'tensor_dot')
|
|
300
|
-
_typecheck_input(x1_type, x2_type, 'tensor_dot')
|
|
301
323
|
# input compatibility check & axes format update
|
|
302
324
|
axes = _axes_int_check(x1_shape, x2_shape, axes, 'tensor_dot')
|
|
303
325
|
_validate_axes(x1_shape, x2_shape, axes, 'tensor_dot')
|
|
@@ -314,7 +336,7 @@ def tensor_dot(x1, x2, axes):
|
|
|
314
336
|
return final_result
|
|
315
337
|
|
|
316
338
|
|
|
317
|
-
@
|
|
339
|
+
@_primexpr
|
|
318
340
|
def _check_invalid_input(x1_shape, x2_shape, prim_name=None):
|
|
319
341
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
320
342
|
if len(x1_shape) < 2 or len(x2_shape) < 2:
|
|
@@ -335,30 +357,30 @@ def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
|
|
|
335
357
|
f"x1_type: {x1_type} and x2_type: {x2_type}.")
|
|
336
358
|
|
|
337
359
|
|
|
338
|
-
@
|
|
360
|
+
@_primexpr
|
|
339
361
|
def _get_transpose_shape(x2_shape):
|
|
340
362
|
x2_shape_range = tuple(range(len(x2_shape)))
|
|
341
363
|
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
|
|
342
364
|
return x2_shape_transpose
|
|
343
365
|
|
|
344
366
|
|
|
345
|
-
def dot(
|
|
367
|
+
def dot(input, other):
|
|
346
368
|
"""
|
|
347
369
|
Computation a dot product between samples in two tensors.
|
|
348
370
|
|
|
349
371
|
Args:
|
|
350
|
-
|
|
372
|
+
input (Tensor): First tensor in Dot op with datatype float16 or float32,
|
|
351
373
|
The rank must be greater than or equal to 2.
|
|
352
|
-
|
|
374
|
+
other (Tensor): Second tensor in Dot op with datatype float16 or float32,
|
|
353
375
|
The rank must be greater than or equal to 2.
|
|
354
376
|
|
|
355
377
|
Returns:
|
|
356
|
-
Tensor, dot product of
|
|
378
|
+
Tensor, dot product of input and other.
|
|
357
379
|
|
|
358
380
|
Raises:
|
|
359
|
-
TypeError: If type of
|
|
360
|
-
TypeError: If dtype of
|
|
361
|
-
ValueError: If rank of
|
|
381
|
+
TypeError: If type of input and other are not the same.
|
|
382
|
+
TypeError: If dtype of input or other is not float16 or float32.
|
|
383
|
+
ValueError: If rank of input or other less than 2.
|
|
362
384
|
|
|
363
385
|
Supported Platforms:
|
|
364
386
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -367,25 +389,25 @@ def dot(x1, x2):
|
|
|
367
389
|
>>> import numpy as np
|
|
368
390
|
>>> import mindspore
|
|
369
391
|
>>> from mindspore import Tensor, ops
|
|
370
|
-
>>>
|
|
371
|
-
>>>
|
|
372
|
-
>>> output = ops.dot(
|
|
392
|
+
>>> input = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
|
|
393
|
+
>>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
|
|
394
|
+
>>> output = ops.dot(input, other)
|
|
373
395
|
>>> print(output)
|
|
374
396
|
[[[3. 3.]]
|
|
375
397
|
[[3. 3.]]]
|
|
376
398
|
>>> print(output.shape)
|
|
377
399
|
(2, 1, 2)
|
|
378
|
-
>>>
|
|
379
|
-
>>>
|
|
380
|
-
>>> output = ops.dot(
|
|
400
|
+
>>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
|
|
401
|
+
>>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
|
|
402
|
+
>>> output = ops.dot(input, other)
|
|
381
403
|
>>> print(output)
|
|
382
404
|
[[[[3. 3.]]
|
|
383
405
|
[[3. 3.]]]]
|
|
384
406
|
>>> print(output.shape)
|
|
385
407
|
(1, 2, 1, 2)
|
|
386
|
-
>>>
|
|
387
|
-
>>>
|
|
388
|
-
>>> output = ops.dot(
|
|
408
|
+
>>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
|
|
409
|
+
>>> other = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
|
|
410
|
+
>>> output = ops.dot(input, other)
|
|
389
411
|
>>> print(output)
|
|
390
412
|
[[[[3. 3.]
|
|
391
413
|
[3. 3.]]
|
|
@@ -393,9 +415,9 @@ def dot(x1, x2):
|
|
|
393
415
|
[3. 3.]]]]
|
|
394
416
|
>>> print(output.shape)
|
|
395
417
|
(1, 2, 2, 2)
|
|
396
|
-
>>>
|
|
397
|
-
>>>
|
|
398
|
-
>>> output = ops.dot(
|
|
418
|
+
>>> input = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
|
|
419
|
+
>>> other = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
|
|
420
|
+
>>> output = ops.dot(input, other)
|
|
399
421
|
>>> print(output)
|
|
400
422
|
[[[[[3. 3.]]
|
|
401
423
|
[[3. 3.]]]
|
|
@@ -416,34 +438,36 @@ def dot(x1, x2):
|
|
|
416
438
|
reshape_op = P.Reshape()
|
|
417
439
|
transpose_op = P.Transpose()
|
|
418
440
|
matmul_op = P.MatMul(False, False)
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
_typecheck_input_dot(
|
|
424
|
-
_check_invalid_input(
|
|
425
|
-
|
|
426
|
-
if len(
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
mul_result = matmul_op(
|
|
432
|
-
reshape_shape =
|
|
441
|
+
input_shape = shape_op(input)
|
|
442
|
+
other_shape = shape_op(other)
|
|
443
|
+
input_type = F.dtype(input)
|
|
444
|
+
other_type = F.dtype(other)
|
|
445
|
+
_typecheck_input_dot(input_type, other_type, 'dot')
|
|
446
|
+
_check_invalid_input(input_shape, other_shape, 'dot')
|
|
447
|
+
|
|
448
|
+
if len(input_shape) > 2 or len(other_shape) > 2:
|
|
449
|
+
other_shape_transpose = _get_transpose_shape(other_shape)
|
|
450
|
+
other_transpose = transpose_op(other, other_shape_transpose)
|
|
451
|
+
input_reshape = reshape_op(input, (-1, input_shape[-1]))
|
|
452
|
+
other_reshape = reshape_op(other_transpose, (other_shape[-2], -1))
|
|
453
|
+
mul_result = matmul_op(input_reshape, other_reshape)
|
|
454
|
+
reshape_shape = input_shape[:-1] + other_shape[:-2] + other_shape[-1:]
|
|
433
455
|
reshape_shape = (-1,) + reshape_shape[1:]
|
|
434
456
|
return reshape_op(mul_result, reshape_shape)
|
|
435
|
-
return matmul_op(
|
|
457
|
+
return matmul_op(input, other)
|
|
436
458
|
|
|
437
459
|
|
|
438
|
-
@
|
|
460
|
+
@_primexpr
|
|
439
461
|
def _get_batch_size(x1_shape, x2_shape, prim_name=None):
|
|
440
462
|
"""
|
|
441
463
|
Get batch sizes from two inputs
|
|
442
464
|
"""
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
465
|
+
def _check():
|
|
466
|
+
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
467
|
+
if len(x1_shape) < 2 or len(x2_shape) < 2:
|
|
468
|
+
raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
|
|
469
|
+
f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
|
|
470
|
+
_check()
|
|
447
471
|
return x1_shape[0], x2_shape[0]
|
|
448
472
|
|
|
449
473
|
|
|
@@ -460,12 +484,33 @@ def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
|
|
|
460
484
|
f"x2_type: {x2_type}.")
|
|
461
485
|
|
|
462
486
|
|
|
463
|
-
@
|
|
487
|
+
@_primexpr
|
|
464
488
|
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
|
|
465
489
|
"""
|
|
466
490
|
Check whether axes are valid and cast axes from tuple to list
|
|
467
491
|
"""
|
|
468
492
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
493
|
+
|
|
494
|
+
def _check_1(axes):
|
|
495
|
+
if 0 in axes:
|
|
496
|
+
raise ValueError(f"{msg_prefix} 'axes' cannot contain 0, but got axes: {axes}.")
|
|
497
|
+
if len(axes) != 2:
|
|
498
|
+
raise ValueError(f"{msg_prefix} length of 'axes' must be equal to 2, but got {len(axes)}.")
|
|
499
|
+
|
|
500
|
+
def _check_2(axes, x1_shape, x2_shape):
|
|
501
|
+
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
|
|
502
|
+
raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
|
|
503
|
+
f"and axes[1] must be less than or equal to len(x2_shape)."
|
|
504
|
+
f"But got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
505
|
+
|
|
506
|
+
def _check_3(axes, x1_shape, x2_shape):
|
|
507
|
+
if axes == 0:
|
|
508
|
+
raise ValueError(f"{msg_prefix} 'axes' should not be equal to 0, but got {axes}.")
|
|
509
|
+
|
|
510
|
+
if axes > len(x1_shape) or axes > len(x2_shape):
|
|
511
|
+
raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
|
|
512
|
+
f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
513
|
+
|
|
469
514
|
if axes is None:
|
|
470
515
|
if len(x2_shape) == 2:
|
|
471
516
|
axes = [len(x1_shape) - 1, len(x2_shape) - 1]
|
|
@@ -473,10 +518,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
473
518
|
axes = [len(x1_shape) - 1, len(x2_shape) - 2]
|
|
474
519
|
|
|
475
520
|
if isinstance(axes, (list, tuple)):
|
|
476
|
-
|
|
477
|
-
raise ValueError(f"{msg_prefix} 'axes' cannot contain 0, but got axes: {axes}.")
|
|
478
|
-
if len(axes) != 2:
|
|
479
|
-
raise ValueError(f"{msg_prefix} length of 'axes' must be equal to 2, but got {len(axes)}.")
|
|
521
|
+
_check_1(axes)
|
|
480
522
|
if isinstance(axes, tuple):
|
|
481
523
|
axes = list(axes)
|
|
482
524
|
validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
|
|
@@ -488,19 +530,12 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
488
530
|
axes[1] += len(x2_shape)
|
|
489
531
|
validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
|
|
490
532
|
validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
|
|
491
|
-
|
|
492
|
-
raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
|
|
493
|
-
f"and axes[1] must be less than or equal to len(x2_shape)."
|
|
494
|
-
f"But got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
533
|
+
_check_2(axes, x1_shape, x2_shape)
|
|
495
534
|
elif isinstance(axes, int):
|
|
496
|
-
|
|
497
|
-
raise ValueError(f"{msg_prefix} 'axes' should not be equal to 0, but got {axes}.")
|
|
535
|
+
_check_3(axes, x1_shape, x2_shape)
|
|
498
536
|
if axes < 0:
|
|
499
537
|
axes = [axes + len(x1_shape), axes + len(x2_shape)]
|
|
500
538
|
validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
|
|
501
|
-
elif axes > len(x1_shape) or axes > len(x2_shape):
|
|
502
|
-
raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
|
|
503
|
-
f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
504
539
|
else:
|
|
505
540
|
axes = [axes, axes]
|
|
506
541
|
else:
|
|
@@ -509,7 +544,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
509
544
|
return axes
|
|
510
545
|
|
|
511
546
|
|
|
512
|
-
@
|
|
547
|
+
@_primexpr
|
|
513
548
|
def _calc_new_shape_batchdot(shape, axes, position=0):
|
|
514
549
|
"""
|
|
515
550
|
Calculate transpose and reshape parameters for input transformations,
|
|
@@ -517,10 +552,14 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
|
|
|
517
552
|
"""
|
|
518
553
|
axis = axes[position]
|
|
519
554
|
contraction_axes = tuple([axis])
|
|
520
|
-
prod_contraction =
|
|
555
|
+
prod_contraction = 1
|
|
556
|
+
for i in contraction_axes:
|
|
557
|
+
prod_contraction *= shape[i]
|
|
521
558
|
free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
|
|
522
559
|
free_dims = tuple(shape[i] for i in free_axes)
|
|
523
|
-
prod_free =
|
|
560
|
+
prod_free = 1
|
|
561
|
+
for free_dim in free_dims:
|
|
562
|
+
prod_free *= free_dim
|
|
524
563
|
|
|
525
564
|
transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
|
|
526
565
|
transpose_perm = tuple([0]) + transpose_perm
|
|
@@ -529,7 +568,7 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
|
|
|
529
568
|
return new_shape, transpose_perm, free_dims
|
|
530
569
|
|
|
531
570
|
|
|
532
|
-
@
|
|
571
|
+
@_primexpr
|
|
533
572
|
def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
|
|
534
573
|
"""
|
|
535
574
|
Check whether batch size of two inputs are the same
|
|
@@ -540,7 +579,7 @@ def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
|
|
|
540
579
|
f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.")
|
|
541
580
|
|
|
542
581
|
|
|
543
|
-
@
|
|
582
|
+
@_primexpr
|
|
544
583
|
def _get_output_shape(batch_size, x1_ret, x2_ret):
|
|
545
584
|
"""
|
|
546
585
|
Compute output shape for batch dot
|
|
@@ -732,6 +771,49 @@ def matmul(x1, x2, dtype=None):
|
|
|
732
771
|
return res
|
|
733
772
|
|
|
734
773
|
|
|
774
|
+
def mm(input, mat2):
|
|
775
|
+
r"""
|
|
776
|
+
Returns the matrix product of two arrays.
|
|
777
|
+
If `input` is a :math:`(n \times m)` Tensor, `mat2` is a
|
|
778
|
+
:math:`(m \times p)` Tensor, `out` will be a :math:`(n \times p)` Tensor.
|
|
779
|
+
|
|
780
|
+
Note:
|
|
781
|
+
This function cannot support broadcasting.
|
|
782
|
+
Refer to :func:`mindspore.ops.matmul` instead if you need a broadcastable function.
|
|
783
|
+
|
|
784
|
+
Args:
|
|
785
|
+
input (Tensor): The first matrix of matrix multiplication.
|
|
786
|
+
The last dimension of `input` must be the same size as the first dimension of `mat2`.
|
|
787
|
+
mat2 (Tensor): The second matrix of matrix multiplication.
|
|
788
|
+
The last dimension of `input` must be the same size as the first dimension of `mat2`.
|
|
789
|
+
|
|
790
|
+
Returns:
|
|
791
|
+
Tensor or scalar, the matrix product of the inputs.
|
|
792
|
+
|
|
793
|
+
Raises:
|
|
794
|
+
ValueError: If the last dimension of `input` is not the same size as the
|
|
795
|
+
second-to-last dimension of `mat2`.
|
|
796
|
+
ValueError: If `input` or `mat2` is not a matrix.
|
|
797
|
+
|
|
798
|
+
Supported Platforms:
|
|
799
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
800
|
+
|
|
801
|
+
Examples:
|
|
802
|
+
>>> import mindspore as ms
|
|
803
|
+
>>> import mindspore.ops as ops
|
|
804
|
+
>>> import numpy as np
|
|
805
|
+
>>> x1 = ms.Tensor(np.random.rand(2, 3))
|
|
806
|
+
>>> x2 = ms.Tensor(np.random.rand(3, 4))
|
|
807
|
+
>>> out = ops.mm(x1, x2)
|
|
808
|
+
>>> print(out.shape)
|
|
809
|
+
(2, 4)
|
|
810
|
+
"""
|
|
811
|
+
if input.ndim != 2 or mat2.ndim != 2:
|
|
812
|
+
raise ValueError(f"For mm, the input tensor must be a matrix, "
|
|
813
|
+
f"but got mat1.ndim:{input.ndim}, mat2.ndim:{mat2.ndim}")
|
|
814
|
+
return matmul(input, mat2)
|
|
815
|
+
|
|
816
|
+
|
|
735
817
|
def cummin(x, axis):
|
|
736
818
|
r"""
|
|
737
819
|
Returns a tuple (values,indices) where 'values' is the cumulative minimum value of input Tensor `x`
|
|
@@ -770,55 +852,3 @@ def cummin(x, axis):
|
|
|
770
852
|
[0 1 1 1 4 4]
|
|
771
853
|
"""
|
|
772
854
|
return cummin_(x, axis)
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
def resize_nearest_neighbor(input_x, size, align_corners=False):
|
|
776
|
-
r"""
|
|
777
|
-
Resizes the input tensor by using the nearest neighbor algorithm.
|
|
778
|
-
|
|
779
|
-
Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
|
|
780
|
-
neighbor algorithm selects the value of the nearest point and does not consider the
|
|
781
|
-
values of neighboring points at all, yielding a piecewise-constant interpolant.
|
|
782
|
-
|
|
783
|
-
Args:
|
|
784
|
-
input_x (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
|
|
785
|
-
size (Union[Tensor, tuple, list]): The target size. The dimension of size must be 2.
|
|
786
|
-
align_corners (bool): Whether the centers of the 4 corner pixels of the input
|
|
787
|
-
and output tensors are aligned. Default: False.
|
|
788
|
-
|
|
789
|
-
Returns:
|
|
790
|
-
Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
|
|
791
|
-
The data type is the same as the `input_x`.
|
|
792
|
-
|
|
793
|
-
Raises:
|
|
794
|
-
TypeError: If `input_x` is not a Tensor.
|
|
795
|
-
TypeError: If `size` is neither tuple nor list.
|
|
796
|
-
TypeError: If `align_corners` is not a bool.
|
|
797
|
-
ValueError: If length of `size` is not equal to 2.
|
|
798
|
-
|
|
799
|
-
Supported Platforms:
|
|
800
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
801
|
-
|
|
802
|
-
Examples:
|
|
803
|
-
>>> import numpy as np
|
|
804
|
-
>>> import mindspore
|
|
805
|
-
>>> from mindspore import Tensor, ops
|
|
806
|
-
>>> input_tensor = Tensor(np.array([[[[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]]]), mindspore.float32)
|
|
807
|
-
>>> size = (2, 2)
|
|
808
|
-
>>> output = ops.ResizeNearestNeighbor(size=size)(input_tensor)
|
|
809
|
-
>>> print(output)
|
|
810
|
-
[[[[-0.1 0.3]
|
|
811
|
-
[ 0.4 0.5]]]]
|
|
812
|
-
"""
|
|
813
|
-
if size is None:
|
|
814
|
-
raise ValueError(f'For ResizeNearestNeighbor, size could not be None.')
|
|
815
|
-
if isinstance(size, (tuple, list)):
|
|
816
|
-
resize = P.ResizeNearestNeighbor(size, align_corners)
|
|
817
|
-
return resize(input_x)
|
|
818
|
-
if is_const(size):
|
|
819
|
-
size = size.asnumpy()
|
|
820
|
-
resize = P.ResizeNearestNeighbor(size, align_corners)
|
|
821
|
-
return resize(input_x)
|
|
822
|
-
|
|
823
|
-
resize = DynamicResizeNearestNeighbor(align_corners)
|
|
824
|
-
return resize(input_x, size)
|