mindspore 2.0.0a0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/image.py
CHANGED
|
@@ -25,8 +25,8 @@ from mindspore.common.tensor import Tensor
|
|
|
25
25
|
from mindspore.ops import operations as P
|
|
26
26
|
from mindspore.ops.operations import _inner_ops as inner
|
|
27
27
|
from mindspore.ops import functional as F
|
|
28
|
-
from mindspore.ops.primitive import constexpr
|
|
29
|
-
from mindspore
|
|
28
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
29
|
+
from mindspore import _checkparam as validator
|
|
30
30
|
from mindspore.nn.layer.conv import Conv2d
|
|
31
31
|
from mindspore.nn.layer.container import CellList
|
|
32
32
|
from mindspore.nn.layer.pooling import AvgPool2d
|
|
@@ -78,8 +78,7 @@ class ImageGradients(Cell):
|
|
|
78
78
|
super(ImageGradients, self).__init__()
|
|
79
79
|
|
|
80
80
|
def construct(self, images):
|
|
81
|
-
|
|
82
|
-
images = F.depend(images, check)
|
|
81
|
+
_check_input_4d(F.shape(images), "images", self.cls_name)
|
|
83
82
|
batch_size, depth, height, width = P.Shape()(images)
|
|
84
83
|
if height == 1:
|
|
85
84
|
dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
|
|
@@ -120,19 +119,18 @@ def _get_dtype_max(dtype):
|
|
|
120
119
|
return dtype_max
|
|
121
120
|
|
|
122
121
|
|
|
123
|
-
@
|
|
122
|
+
@_primexpr
|
|
124
123
|
def _check_input_4d(input_shape, param_name, func_name):
|
|
125
124
|
if len(input_shape) != 4:
|
|
126
125
|
raise ValueError(f"For '{func_name}', the dimension of '{param_name}' must be 4d, "
|
|
127
126
|
f"but got {len(input_shape)}.")
|
|
128
|
-
return True
|
|
129
127
|
|
|
130
128
|
|
|
131
|
-
@
|
|
129
|
+
@_primexpr
|
|
132
130
|
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
|
|
133
131
|
_check_input_4d(input_shape, param_name, func_name)
|
|
134
|
-
validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size,
|
|
135
|
-
validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size,
|
|
132
|
+
validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, validator.GE, func_name)
|
|
133
|
+
validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, validator.GE, func_name)
|
|
136
134
|
|
|
137
135
|
|
|
138
136
|
@constexpr
|
|
@@ -267,9 +265,9 @@ class SSIM(Cell):
|
|
|
267
265
|
def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
|
|
268
266
|
super(SSIM, self).__init__()
|
|
269
267
|
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
|
|
270
|
-
validator.check_number('max_val', max_val, 0.0,
|
|
268
|
+
validator.check_number('max_val', max_val, 0.0, validator.GT, self.cls_name)
|
|
271
269
|
self.max_val = max_val
|
|
272
|
-
self.filter_size = validator.check_int(filter_size, 1,
|
|
270
|
+
self.filter_size = validator.check_int(filter_size, 1, validator.GE, 'filter_size', self.cls_name)
|
|
273
271
|
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
|
|
274
272
|
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
|
|
275
273
|
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
|
|
@@ -363,10 +361,10 @@ class MSSSIM(Cell):
|
|
|
363
361
|
filter_sigma=1.5, k1=0.01, k2=0.03):
|
|
364
362
|
super(MSSSIM, self).__init__()
|
|
365
363
|
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
|
|
366
|
-
validator.check_number('max_val', max_val, 0.0,
|
|
364
|
+
validator.check_number('max_val', max_val, 0.0, validator.GT, self.cls_name)
|
|
367
365
|
self.max_val = max_val
|
|
368
366
|
validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
|
|
369
|
-
self.filter_size = validator.check_int(filter_size, 1,
|
|
367
|
+
self.filter_size = validator.check_int(filter_size, 1, validator.GE, 'filter_size', self.cls_name)
|
|
370
368
|
self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
|
|
371
369
|
self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
|
|
372
370
|
self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
|
|
@@ -462,7 +460,7 @@ class PSNR(Cell):
|
|
|
462
460
|
def __init__(self, max_val=1.0):
|
|
463
461
|
super(PSNR, self).__init__()
|
|
464
462
|
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
|
|
465
|
-
validator.check_number('max_val', max_val, 0.0,
|
|
463
|
+
validator.check_number('max_val', max_val, 0.0, validator.GT, self.cls_name)
|
|
466
464
|
self.max_val = max_val
|
|
467
465
|
|
|
468
466
|
def construct(self, img1, img2):
|
|
@@ -481,22 +479,26 @@ class PSNR(Cell):
|
|
|
481
479
|
return psnr
|
|
482
480
|
|
|
483
481
|
|
|
484
|
-
@
|
|
485
|
-
def
|
|
482
|
+
@_primexpr
|
|
483
|
+
def _check_rank(rank, input_shape, param_name, func_name):
|
|
486
484
|
"""raise error if input is not 3d or 4d"""
|
|
487
|
-
|
|
485
|
+
def _check():
|
|
486
|
+
if rank not in (3, 4):
|
|
487
|
+
raise ValueError(f"{func_name} {param_name} must be 3d or 4d, but got shape {input_shape}")
|
|
488
|
+
_check()
|
|
488
489
|
|
|
489
490
|
|
|
490
|
-
@
|
|
491
|
+
@_primexpr
|
|
491
492
|
def _get_bbox(rank, shape, central_fraction):
|
|
492
493
|
"""get bbox start and size for slice"""
|
|
494
|
+
n, c, h, w = -1, -1, -1, -1
|
|
493
495
|
if rank == 3:
|
|
494
496
|
c, h, w = shape
|
|
495
497
|
else:
|
|
496
498
|
n, c, h, w = shape
|
|
497
499
|
|
|
498
|
-
bbox_h_start = int((float(h) -
|
|
499
|
-
bbox_w_start = int((float(w) -
|
|
500
|
+
bbox_h_start = int((float(h) - float(h * central_fraction)) / 2)
|
|
501
|
+
bbox_w_start = int((float(w) - float(w * central_fraction)) / 2)
|
|
500
502
|
bbox_h_size = h - bbox_h_start * 2
|
|
501
503
|
bbox_w_size = w - bbox_w_start * 2
|
|
502
504
|
|
|
@@ -541,15 +543,14 @@ class CentralCrop(Cell):
|
|
|
541
543
|
def __init__(self, central_fraction):
|
|
542
544
|
super(CentralCrop, self).__init__()
|
|
543
545
|
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
|
|
544
|
-
validator.check_float_range(central_fraction, 0.0, 1.0,
|
|
546
|
+
validator.check_float_range(central_fraction, 0.0, 1.0, validator.INC_RIGHT, 'central_fraction', self.cls_name)
|
|
545
547
|
self.central_fraction = central_fraction
|
|
546
548
|
self.slice = P.Slice()
|
|
547
549
|
|
|
548
550
|
def construct(self, image):
|
|
549
551
|
image_shape = F.shape(image)
|
|
550
552
|
rank = len(image_shape)
|
|
551
|
-
|
|
552
|
-
return _raise_dims_rank_error(image_shape, "image", self.cls_name)
|
|
553
|
+
_check_rank(rank, image_shape, "image", self.cls_name)
|
|
553
554
|
if self.central_fraction == 1.0:
|
|
554
555
|
return image
|
|
555
556
|
|
|
@@ -561,63 +562,69 @@ class CentralCrop(Cell):
|
|
|
561
562
|
|
|
562
563
|
class PixelShuffle(Cell):
|
|
563
564
|
r"""
|
|
564
|
-
Applies
|
|
565
|
-
|
|
565
|
+
Applies the PixelShuffle operation over input which implements sub-pixel convolutions
|
|
566
|
+
with stride :math:`1/r` . For more details, refer to
|
|
566
567
|
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
|
|
567
568
|
<https://arxiv.org/abs/1609.05158>`_ .
|
|
568
569
|
|
|
569
570
|
Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
|
|
570
571
|
:math:`(*, C, H \times r, W \times r)`, where r is an upscale factor and * is zero or more batch dimensions.
|
|
571
572
|
|
|
573
|
+
Note:
|
|
574
|
+
The dimension of input Tensor on Ascend should be less than 7.
|
|
575
|
+
|
|
572
576
|
Args:
|
|
573
|
-
upscale_factor (int):
|
|
577
|
+
upscale_factor (int): factor to shuffle the input, and is a positive integer.
|
|
578
|
+
`upscale_factor` is the above-mentioned :math:`r`.
|
|
574
579
|
|
|
575
580
|
Inputs:
|
|
576
|
-
- **
|
|
577
|
-
the length of third to last dimension can be divisible by `upscale_factor` squared.
|
|
581
|
+
- **input** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` . The dimension of `x` is larger than 2,
|
|
582
|
+
and the length of third to last dimension can be divisible by `upscale_factor` squared.
|
|
578
583
|
|
|
579
584
|
Outputs:
|
|
580
585
|
- **output** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` .
|
|
581
586
|
|
|
582
587
|
Raises:
|
|
583
588
|
ValueError: If `upscale_factor` is not a positive integer.
|
|
584
|
-
ValueError: If the length of third to last dimension of `
|
|
585
|
-
TypeError: If the dimension of `
|
|
589
|
+
ValueError: If the length of third to last dimension of `input` is not divisible by `upscale_factor` squared.
|
|
590
|
+
TypeError: If the dimension of `input` is less than 3.
|
|
586
591
|
|
|
587
592
|
Supported Platforms:
|
|
588
593
|
``Ascend`` ``GPU`` ``CPU``
|
|
589
594
|
|
|
590
595
|
Examples:
|
|
591
|
-
>>> input_x = np.arange(3 * 2 *
|
|
596
|
+
>>> input_x = np.arange(3 * 2 * 8 * 4 * 4).reshape((3, 2, 8, 4, 4))
|
|
592
597
|
>>> input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
|
593
|
-
>>> pixel_shuffle = nn.PixelShuffle(
|
|
598
|
+
>>> pixel_shuffle = nn.PixelShuffle(2)
|
|
594
599
|
>>> output = pixel_shuffle(input_x)
|
|
595
600
|
>>> print(output.shape)
|
|
596
|
-
(3, 2,
|
|
601
|
+
(3, 2, 2, 8, 8)
|
|
597
602
|
"""
|
|
598
603
|
def __init__(self, upscale_factor):
|
|
599
604
|
super(PixelShuffle, self).__init__()
|
|
600
605
|
self.upscale_factor = upscale_factor
|
|
601
606
|
|
|
602
|
-
def construct(self,
|
|
603
|
-
return ops.pixel_shuffle(
|
|
607
|
+
def construct(self, input):
|
|
608
|
+
return ops.pixel_shuffle(input, self.upscale_factor)
|
|
604
609
|
|
|
605
610
|
|
|
606
611
|
class PixelUnshuffle(Cell):
|
|
607
612
|
r"""
|
|
608
|
-
Applies
|
|
609
|
-
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
|
|
613
|
+
Applies the PixelUnshuffle operation over input which is the inverse of PixelShuffle. For more details, refer
|
|
614
|
+
to `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
|
|
610
615
|
<https://arxiv.org/abs/1609.05158>`_ .
|
|
611
616
|
|
|
612
617
|
Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
|
|
613
618
|
:math:`(*, C \times r^2, H, W)` , where r is a downscale factor and * is zero or more batch dimensions.
|
|
614
619
|
|
|
615
620
|
Args:
|
|
616
|
-
downscale_factor (int): factor to
|
|
621
|
+
downscale_factor (int): factor to unshuffle the input, and is a positive integer.
|
|
622
|
+
`downscale_factor` is the above-mentioned :math:`r`.
|
|
617
623
|
|
|
618
624
|
Inputs:
|
|
619
|
-
- **
|
|
620
|
-
2, and the length of second to last dimension or last dimension can be divisible by
|
|
625
|
+
- **input** (Tensor) - Tensor of shape :math:`(*, C, H \times r, W \times r)` . The dimension of `input` is
|
|
626
|
+
larger than 2, and the length of second to last dimension or last dimension can be divisible by
|
|
627
|
+
`downscale_factor` .
|
|
621
628
|
|
|
622
629
|
Outputs:
|
|
623
630
|
- **output** (Tensor) - Tensor of shape :math:`(*, C \times r^2, H, W)` .
|
|
@@ -625,22 +632,22 @@ class PixelUnshuffle(Cell):
|
|
|
625
632
|
Raises:
|
|
626
633
|
ValueError: If `downscale_factor` is not a positive integer.
|
|
627
634
|
ValueError: If the length of second to last dimension or last dimension is not divisible by `downscale_factor` .
|
|
628
|
-
TypeError: If the dimension of `
|
|
635
|
+
TypeError: If the dimension of `input` is less than 3.
|
|
629
636
|
|
|
630
637
|
Supported Platforms:
|
|
631
638
|
``Ascend`` ``GPU`` ``CPU``
|
|
632
639
|
|
|
633
640
|
Examples:
|
|
634
|
-
>>> pixel_unshuffle = nn.PixelUnshuffle(
|
|
635
|
-
>>> input_x = np.arange(
|
|
641
|
+
>>> pixel_unshuffle = nn.PixelUnshuffle(2)
|
|
642
|
+
>>> input_x = np.arange(8 * 8).reshape((1, 1, 8, 8))
|
|
636
643
|
>>> input_x = mindspore.Tensor(input_x, mindspore.dtype.int32)
|
|
637
644
|
>>> output = pixel_unshuffle(input_x)
|
|
638
645
|
>>> print(output.shape)
|
|
639
|
-
(1,
|
|
646
|
+
(1, 4, 4, 4)
|
|
640
647
|
"""
|
|
641
648
|
def __init__(self, downscale_factor):
|
|
642
649
|
super(PixelUnshuffle, self).__init__()
|
|
643
650
|
self.downscale_factor = downscale_factor
|
|
644
651
|
|
|
645
|
-
def construct(self,
|
|
646
|
-
return ops.pixel_unshuffle(
|
|
652
|
+
def construct(self, input):
|
|
653
|
+
return ops.pixel_unshuffle(input, self.downscale_factor)
|
mindspore/nn/layer/math.py
CHANGED
|
@@ -17,15 +17,15 @@ from __future__ import absolute_import
|
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
|
|
20
|
+
from mindspore import log as logger
|
|
20
21
|
from mindspore.ops import operations as P
|
|
21
22
|
from mindspore.common.tensor import Tensor
|
|
22
23
|
from mindspore.common._decorator import deprecated
|
|
23
|
-
from mindspore.ops.primitive import constexpr
|
|
24
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
24
25
|
from mindspore.ops import functional as F
|
|
25
26
|
from mindspore.nn.cell import Cell
|
|
26
27
|
from mindspore.common import dtype as mstype
|
|
27
|
-
from mindspore
|
|
28
|
-
from mindspore.ops._utils.utils import is_shape_unknown
|
|
28
|
+
from mindspore import _checkparam as validator
|
|
29
29
|
|
|
30
30
|
__all__ = ['ReduceLogSumExp',
|
|
31
31
|
'Range',
|
|
@@ -33,6 +33,7 @@ __all__ = ['ReduceLogSumExp',
|
|
|
33
33
|
'DiGamma',
|
|
34
34
|
'IGamma',
|
|
35
35
|
'LBeta',
|
|
36
|
+
'CosineSimilarity',
|
|
36
37
|
'MatMul',
|
|
37
38
|
'Moments',
|
|
38
39
|
'MatInverse',
|
|
@@ -121,38 +122,15 @@ class ReduceLogSumExp(Cell):
|
|
|
121
122
|
|
|
122
123
|
class Range(Cell):
|
|
123
124
|
r"""
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
The size of output is :math:`\left \lfloor \frac{limit-start}{delta} \right \rfloor + 1` and `delta` is the gap
|
|
127
|
-
between two values in the tensor.
|
|
128
|
-
|
|
129
|
-
.. math::
|
|
130
|
-
|
|
131
|
-
out_{i+1} = out_{i} +delta
|
|
132
|
-
|
|
133
|
-
Args:
|
|
134
|
-
start (Union[int, float]): If `limit` is `None`, the value acts as limit in the range and first entry
|
|
135
|
-
defaults to `0`. Otherwise, it acts as first entry in the range.
|
|
136
|
-
limit (Union[int, float]): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
|
|
137
|
-
while set the first entry of the range to `0`. It can not be equal to `start`. Default: None.
|
|
138
|
-
delta (Union[int, float]): Increment of the range. It can not be equal to zero. Default: 1.
|
|
139
|
-
|
|
140
|
-
Outputs:
|
|
141
|
-
Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float.
|
|
142
|
-
|
|
143
|
-
Supported Platforms:
|
|
144
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
145
|
-
|
|
146
|
-
Examples:
|
|
147
|
-
>>> net = nn.Range(1, 8, 2)
|
|
148
|
-
>>> output = net()
|
|
149
|
-
>>> print(output)
|
|
150
|
-
[1 3 5 7]
|
|
125
|
+
'nn.Range' is deprecated from version 2.0 and will be removed in a future version,
|
|
126
|
+
use 'ops.range' instead.
|
|
151
127
|
"""
|
|
152
128
|
|
|
153
129
|
def __init__(self, start, limit=None, delta=1):
|
|
154
130
|
"""Initialize Range."""
|
|
155
131
|
super(Range, self).__init__()
|
|
132
|
+
logger.warning("'nn.Range' is deprecated from version 2.0 and will be removed in a future version,"
|
|
133
|
+
"use 'ops.range' instead.")
|
|
156
134
|
if delta == 0:
|
|
157
135
|
raise ValueError(f"For '{self.cls_name}', the 'delta' can not be zero.")
|
|
158
136
|
data = np.arange(start, limit, delta)
|
|
@@ -256,7 +234,7 @@ class LGamma(Cell):
|
|
|
256
234
|
def construct(self, x):
|
|
257
235
|
input_dtype = self.dtype(x)
|
|
258
236
|
_check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
|
259
|
-
if
|
|
237
|
+
if F.is_sequence_value_unknown(self.shape(x)):
|
|
260
238
|
infinity = self.ones_like(x) * F.cast(self.inf, input_dtype)
|
|
261
239
|
else:
|
|
262
240
|
infinity = self.fill(input_dtype, self.shape(x), self.inf)
|
|
@@ -590,7 +568,7 @@ class IGamma(Cell):
|
|
|
590
568
|
or if x has different dtype with a.
|
|
591
569
|
|
|
592
570
|
Supported Platforms:
|
|
593
|
-
``Ascend`` ``GPU``
|
|
571
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
594
572
|
|
|
595
573
|
Examples:
|
|
596
574
|
>>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
|
|
@@ -636,9 +614,8 @@ class IGamma(Cell):
|
|
|
636
614
|
ax = a * self.log(x) - x - self.lgamma(a)
|
|
637
615
|
para_shape = self.shape(ax)
|
|
638
616
|
if para_shape != ():
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
a = broadcastto(a)
|
|
617
|
+
x = F.broadcast_to(x, para_shape)
|
|
618
|
+
a = F.broadcast_to(a, para_shape)
|
|
642
619
|
x_is_zero = self.equal(x, 0)
|
|
643
620
|
log_maxfloat = self.log_maxfloat32
|
|
644
621
|
underflow = self.less(ax, self.neg(log_maxfloat))
|
|
@@ -721,9 +698,8 @@ class LBeta(Cell):
|
|
|
721
698
|
x_plus_y = x + y
|
|
722
699
|
para_shape = self.shape(x_plus_y)
|
|
723
700
|
if para_shape != ():
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
y = broadcastto(y)
|
|
701
|
+
x = F.broadcast_to(x, para_shape)
|
|
702
|
+
y = F.broadcast_to(y, para_shape)
|
|
727
703
|
comp_less = self.less(x, y)
|
|
728
704
|
x_min = self.select(comp_less, x, y)
|
|
729
705
|
y_max = self.select(comp_less, y, x)
|
|
@@ -761,14 +737,17 @@ class LBeta(Cell):
|
|
|
761
737
|
return self.select(comp_xless8, temp, log_beta_two_large)
|
|
762
738
|
|
|
763
739
|
|
|
764
|
-
@
|
|
740
|
+
@_primexpr
|
|
765
741
|
def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
|
|
766
742
|
"""get broadcast_matmul shape"""
|
|
767
743
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
744
|
+
|
|
745
|
+
def _check_len():
|
|
746
|
+
if (len(x_shape) < 2) or (len(y_shape) < 2):
|
|
747
|
+
raise ValueError(f"{msg_prefix} length of 'x_shape' and 'y_shape' must be equal to or greater than 2, "
|
|
748
|
+
f"but got the length of 'x_shape': {len(x_shape)} and the length of 'y_shape': "
|
|
749
|
+
f"{len(y_shape)}.")
|
|
750
|
+
_check_len()
|
|
772
751
|
x_shape_batch = x_shape[:-2]
|
|
773
752
|
y_shape_batch = y_shape[:-2]
|
|
774
753
|
if x_shape_batch == y_shape_batch:
|
|
@@ -777,17 +756,21 @@ def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
|
|
|
777
756
|
y_len = len(y_shape)
|
|
778
757
|
length = x_len if x_len < y_len else y_len
|
|
779
758
|
broadcast_shape_back = []
|
|
759
|
+
|
|
760
|
+
def _check_broadcast(x_val, y_val, i):
|
|
761
|
+
if not (x_val == 1 or y_val == 1 or x_val == y_val):
|
|
762
|
+
raise ValueError(f"{msg_prefix} 'x_shape[{i}]' must be equal to 1, or the 'y_shape[{i}]' must be equal "
|
|
763
|
+
f"to 1, or the 'x_shape[{i}]' must be equal to 'y_shape[{i}]', but got "
|
|
764
|
+
f"'x_shape[{i}]': {x_val}, 'y_shape[{i}]': {y_val}.")
|
|
765
|
+
|
|
780
766
|
for i in range(-length, -2):
|
|
767
|
+
_check_broadcast(x_shape[i], y_shape[i], i)
|
|
781
768
|
if x_shape[i] == 1:
|
|
782
769
|
broadcast_shape_back.append(y_shape[i])
|
|
783
770
|
elif y_shape[i] == 1:
|
|
784
771
|
broadcast_shape_back.append(x_shape[i])
|
|
785
|
-
elif x_shape[i] == y_shape[i]:
|
|
786
|
-
broadcast_shape_back.append(x_shape[i])
|
|
787
772
|
else:
|
|
788
|
-
|
|
789
|
-
f"to 1, or the 'x_shape[{i}]' must be equal to 'y_shape[{i}]', but got "
|
|
790
|
-
f"'x_shape[{i}]': {x_shape[i]}, 'y_shape[{i}]': {y_shape[i]}.")
|
|
773
|
+
broadcast_shape_back.append(x_shape[i])
|
|
791
774
|
|
|
792
775
|
broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
|
|
793
776
|
x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:]
|
|
@@ -795,10 +778,15 @@ def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
|
|
|
795
778
|
return x_broadcast_shape, y_broadcast_shape
|
|
796
779
|
|
|
797
780
|
|
|
798
|
-
@
|
|
781
|
+
@_primexpr
|
|
799
782
|
def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2, prim_name=None):
|
|
800
783
|
"""check col and row equal"""
|
|
801
|
-
|
|
784
|
+
def _check(x1_col, x2_row):
|
|
785
|
+
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
786
|
+
if x1_col != x2_row:
|
|
787
|
+
raise ValueError(f"{msg_prefix} column of matrix dimensions of 'x1' must be equal to "
|
|
788
|
+
f"the row of matrix dimensions of 'x2', but got 'x1_col' {x1_col} and 'x2_row' {x2_row}.")
|
|
789
|
+
|
|
802
790
|
if len(x1_shape) == 1:
|
|
803
791
|
transpose_x1 = False
|
|
804
792
|
x1_shape = (1,) + x1_shape
|
|
@@ -809,9 +797,7 @@ def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2, prim_nam
|
|
|
809
797
|
x2_last = x2_shape[-2:]
|
|
810
798
|
x1_col = x1_last[not transpose_x1] # x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
|
|
811
799
|
x2_row = x2_last[transpose_x2] # x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
|
|
812
|
-
|
|
813
|
-
raise ValueError(f"{msg_prefix} column of matrix dimensions of 'x1' must be equal to "
|
|
814
|
-
f"the row of matrix dimensions of 'x2', but got 'x1_col' {x1_col} and 'x2_row' {x2_row}.")
|
|
800
|
+
_check(x1_col, x2_row)
|
|
815
801
|
|
|
816
802
|
|
|
817
803
|
def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
|
|
@@ -872,12 +858,10 @@ class MatMul(Cell):
|
|
|
872
858
|
x2_shape = self.shape_op(x2)
|
|
873
859
|
|
|
874
860
|
x1_broadcast_shape, x2_broadcast_shape = get_broadcast_matmul_shape(x1_shape, x2_shape)
|
|
875
|
-
x1_broadcast_to = P.BroadcastTo(x1_broadcast_shape)
|
|
876
|
-
x2_broadcast_to = P.BroadcastTo(x2_broadcast_shape)
|
|
877
861
|
if x1_broadcast_shape != x1_shape:
|
|
878
|
-
x1 =
|
|
862
|
+
x1 = F.broadcast_to(x1, x1_broadcast_shape)
|
|
879
863
|
if x2_broadcast_shape != x2_shape:
|
|
880
|
-
x2 =
|
|
864
|
+
x2 = F.broadcast_to(x2, x2_broadcast_shape)
|
|
881
865
|
|
|
882
866
|
matmul_broadcast = matmul_op(x1, x2)
|
|
883
867
|
|
|
@@ -889,72 +873,80 @@ class MatMul(Cell):
|
|
|
889
873
|
return matmul_broadcast
|
|
890
874
|
|
|
891
875
|
|
|
892
|
-
class
|
|
893
|
-
"""
|
|
894
|
-
|
|
876
|
+
class CosineSimilarity(Cell):
|
|
877
|
+
r"""
|
|
878
|
+
Computes cosine similarity.
|
|
879
|
+
|
|
880
|
+
.. math::
|
|
881
|
+
\mathcal{K} = \frac{\textbf{x}\textbf{y}^{\top}}{\parallel \textbf{x} \parallel \parallel \textbf{y} \parallel},
|
|
882
|
+
|
|
883
|
+
where :math:`\mathcal{K}` is the similarity, :math:`\textbf{x}` is the first tensor `x1`,
|
|
884
|
+
:math:`\textbf{y}` is the second tensor `x2`.
|
|
885
|
+
|
|
886
|
+
To avoid numerical errors when dividing by small numbers,
|
|
887
|
+
the lower bound of :math:`\parallel \textbf{x} \parallel \parallel \textbf{y} \parallel` is set to `eps`.
|
|
895
888
|
|
|
896
889
|
Args:
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
keep_dims (Union[bool, None]): If True, the calculation result will retain the dimension of `axis`,
|
|
900
|
-
and the dimensions of the mean and variance are the same as the input. If False or None,
|
|
901
|
-
the dimension of `axis` will be reduced. Default: None.
|
|
890
|
+
dim (int, optional): Dimension. Default: 1.
|
|
891
|
+
eps (float, optional): Small value. Default: 1e-08.
|
|
902
892
|
|
|
903
893
|
Inputs:
|
|
904
|
-
- **
|
|
905
|
-
|
|
894
|
+
- **x1** (Tensor) - The first tensor :math:`\textbf{x}`.
|
|
895
|
+
Shape: :math:`(\ast_1, D, \ast_2)` where :math:`D` is at position `dim`.
|
|
896
|
+
- **x2** (Tensor) - The second tensor :math:`\textbf{y}`. The shape is the same as `x1`.
|
|
906
897
|
|
|
907
898
|
Outputs:
|
|
908
|
-
|
|
909
|
-
- **variance** (Tensor) - The variance of `x` on `axis`, with the same data type as input `x`.
|
|
899
|
+
Tensor, with shape :math:`(\ast_1, \ast_2)`, the data type will be inferred automatically.
|
|
910
900
|
|
|
911
901
|
Raises:
|
|
912
|
-
TypeError: If `
|
|
913
|
-
TypeError: If `keep_dims` is neither bool nor None.
|
|
914
|
-
TypeError: If dtype of `x` is neither float16 nor float32.
|
|
902
|
+
TypeError: If `x1` or `x2` is not a Tensor.
|
|
915
903
|
|
|
916
904
|
Supported Platforms:
|
|
917
905
|
``Ascend`` ``GPU`` ``CPU``
|
|
918
906
|
|
|
919
907
|
Examples:
|
|
920
|
-
>>>
|
|
921
|
-
>>>
|
|
922
|
-
>>>
|
|
923
|
-
>>>
|
|
924
|
-
>>> print(
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
908
|
+
>>> x1 = Tensor([[1.0, 3.0, 4.0, 7.0], [2.0, 4.0, 2.0, 5.0], [3.0, 1.0, 5.0, 8.0]])
|
|
909
|
+
>>> x2 = Tensor([[2.0, 4.0, 2.0, 5.0], [3.0, 1.0, 5.0, 8.0], [1.0, 3.0, 4.0, 7.0]])
|
|
910
|
+
>>> func = nn.layer.CosineSimilarity()
|
|
911
|
+
>>> out = func(x1, x2)
|
|
912
|
+
>>> print(out.asnumpy())
|
|
913
|
+
[0.9402562 0.8614609 0.9516245]
|
|
914
|
+
"""
|
|
915
|
+
|
|
916
|
+
def __init__(self, dim=1, eps=1e-08):
|
|
917
|
+
"""Initialize CosineSimilarity."""
|
|
918
|
+
super().__init__()
|
|
919
|
+
self.dim = dim
|
|
920
|
+
self.eps = eps
|
|
921
|
+
self.mul = P.Mul()
|
|
922
|
+
self.div = P.DivNoNan()
|
|
923
|
+
self.maximum = P.Maximum()
|
|
924
|
+
self.cast = P.Cast()
|
|
925
|
+
|
|
926
|
+
def construct(self, x1, x2):
|
|
927
|
+
if not isinstance(x1, Tensor):
|
|
928
|
+
raise TypeError(f"For 'CosineSimilarity', 'x1' must be a tensor, but got {type(x1)}")
|
|
929
|
+
if not isinstance(x2, Tensor):
|
|
930
|
+
raise TypeError(f"For 'CosineSimilarity', 'x2' must be a tensor, but got {type(x2)}")
|
|
931
|
+
w12 = self.mul(x1, x2).sum(self.dim)
|
|
932
|
+
w1 = self.mul(x1, x1).sum(self.dim)
|
|
933
|
+
w2 = self.mul(x2, x2).sum(self.dim)
|
|
934
|
+
n12 = self.maximum(self.mul(w1, w2), self.eps * self.eps).sqrt()
|
|
935
|
+
out = self.div(w12, n12)
|
|
936
|
+
return out
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
class Moments(Cell):
|
|
940
|
+
"""
|
|
941
|
+
'nn.Moments' is deprecated from version 2.0 and will be removed in a future version,
|
|
942
|
+
use 'ops.var_mean' instead.
|
|
953
943
|
"""
|
|
954
944
|
|
|
955
945
|
def __init__(self, axis=None, keep_dims=None):
|
|
956
946
|
"""Initialize Moments."""
|
|
957
947
|
super(Moments, self).__init__()
|
|
948
|
+
logger.warning("'nn.Moments' is deprecated from version 2.0 and will be removed in a future version,"
|
|
949
|
+
"use 'ops.var_mean' instead.")
|
|
958
950
|
if axis is None:
|
|
959
951
|
axis = ()
|
|
960
952
|
if isinstance(axis, tuple):
|