mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
|
|
16
16
|
"""Define the grad rules of math related operations."""
|
|
17
17
|
|
|
18
|
-
from functools import reduce
|
|
19
18
|
import numpy as np
|
|
20
19
|
import mindspore as ms
|
|
21
20
|
from mindspore import nn
|
|
@@ -30,10 +29,9 @@ from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_elemen
|
|
|
30
29
|
from mindspore.ops._grad.grad_base import convert_to_tensor
|
|
31
30
|
from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
|
|
32
31
|
from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
|
|
33
|
-
from mindspore.ops.primitive import
|
|
32
|
+
from mindspore.ops.primitive import _primexpr
|
|
34
33
|
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
35
|
-
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs,
|
|
36
|
-
from mindspore.ops._utils.utils import is_shape_unknown, is_dim_unknown
|
|
34
|
+
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass, DynamicBroadcastTo
|
|
37
35
|
from mindspore.ops.operations import array_ops as A
|
|
38
36
|
|
|
39
37
|
shape_op = P.Shape()
|
|
@@ -114,7 +112,7 @@ def binop_grad_common(x, y, dx, dy):
|
|
|
114
112
|
# if input shape is the same as dout shape, do not need to reduce
|
|
115
113
|
reduce_dx = dx
|
|
116
114
|
reduce_dy = dy
|
|
117
|
-
if not (
|
|
115
|
+
if not (F.is_sequence_value_unknown(shape_of_x) or F.is_sequence_value_unknown(shape_of_y)):
|
|
118
116
|
rx = broadcast_gradient_args(shape_of_x, shape_of_y)
|
|
119
117
|
if rx[0]:
|
|
120
118
|
# if dx is scalar whose shape is (), do not need reduce
|
|
@@ -127,11 +125,12 @@ def binop_grad_common(x, y, dx, dy):
|
|
|
127
125
|
dy = _reduce_sum_with_cast(dy, rx[1])
|
|
128
126
|
reduce_dy = reshape(dy, shape_of_y)
|
|
129
127
|
return reduce_dx, reduce_dy
|
|
130
|
-
|
|
128
|
+
|
|
129
|
+
if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
|
|
131
130
|
# x or y is scalar
|
|
132
|
-
if not shape_of_x:
|
|
131
|
+
if not isinstance(shape_of_x, tuple):
|
|
133
132
|
reduce_dx = _reduce_sum_with_cast(dx, ())
|
|
134
|
-
if not shape_of_y:
|
|
133
|
+
if not isinstance(shape_of_y, tuple):
|
|
135
134
|
reduce_dy = _reduce_sum_with_cast(dy, ())
|
|
136
135
|
return reduce_dx, reduce_dy
|
|
137
136
|
|
|
@@ -151,7 +150,7 @@ def binop_grad_common_with_shift(x, y, dx, dy, shift):
|
|
|
151
150
|
# if input shape is the same as dout shape, do not need to reduce
|
|
152
151
|
reduce_dx = dx
|
|
153
152
|
reduce_dy = dy
|
|
154
|
-
if not (
|
|
153
|
+
if not (F.is_sequence_value_unknown(broadcast_shape_of_x) or F.is_sequence_value_unknown(broadcast_shape_of_y)):
|
|
155
154
|
rx = broadcast_gradient_args(broadcast_shape_of_x, broadcast_shape_of_y)
|
|
156
155
|
if rx[0]:
|
|
157
156
|
# if dx is scalar whose shape is (), do not need reduce
|
|
@@ -164,11 +163,12 @@ def binop_grad_common_with_shift(x, y, dx, dy, shift):
|
|
|
164
163
|
dy = _reduce_sum_with_cast(dy, rx[1])
|
|
165
164
|
reduce_dy = reshape(dy, shape_of_y)
|
|
166
165
|
return reduce_dx, reduce_dy
|
|
167
|
-
|
|
166
|
+
|
|
167
|
+
if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
|
|
168
168
|
# x or y is scalar
|
|
169
|
-
if not shape_of_x:
|
|
169
|
+
if not isinstance(shape_of_x, tuple):
|
|
170
170
|
reduce_dx = _reduce_sum_with_cast(dx, ())
|
|
171
|
-
if not shape_of_y:
|
|
171
|
+
if not isinstance(shape_of_y, tuple):
|
|
172
172
|
reduce_dy = _reduce_sum_with_cast(dy, ())
|
|
173
173
|
return reduce_dx, reduce_dy
|
|
174
174
|
|
|
@@ -178,7 +178,7 @@ def binop_grad_common_with_shift(x, y, dx, dy, shift):
|
|
|
178
178
|
def _dyn_reduced_shape(input_shape, axis, x):
|
|
179
179
|
"""Dynamic reduce shape"""
|
|
180
180
|
input_shape = P.Cast()(input_shape, ms.int32)
|
|
181
|
-
if x is not None and not
|
|
181
|
+
if x is not None and not F.is_sequence_shape_unknown(shape_op(x)):
|
|
182
182
|
input_rank = len(shape_op(x))
|
|
183
183
|
else:
|
|
184
184
|
input_rank = dyn_rank(x)
|
|
@@ -209,7 +209,7 @@ def _sum_grad(x, axis, dout):
|
|
|
209
209
|
"""Grad definition for `Sum` operation."""
|
|
210
210
|
input_shape = shape_op(x)
|
|
211
211
|
is_mutable, axis = convert_to_tensor(axis)
|
|
212
|
-
if
|
|
212
|
+
if F.is_sequence_value_unknown(input_shape) or is_mutable:
|
|
213
213
|
input_shape = dyn_shape_op(x)
|
|
214
214
|
output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
|
|
215
215
|
output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int32)
|
|
@@ -226,7 +226,7 @@ def _min_or_max_grad(x, axis, out, dout):
|
|
|
226
226
|
"""Grad definition for `Min` and `Max` operations."""
|
|
227
227
|
input_shape = shape_op(x)
|
|
228
228
|
output_shape_kept_dims = ()
|
|
229
|
-
if
|
|
229
|
+
if F.is_sequence_value_unknown(input_shape):
|
|
230
230
|
input_shape = dyn_shape_op(x)
|
|
231
231
|
output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
|
|
232
232
|
output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int32)
|
|
@@ -268,7 +268,7 @@ def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
|
|
|
268
268
|
x_axis = axis
|
|
269
269
|
onehot_axis_is_neg = False
|
|
270
270
|
if x_axis < 0:
|
|
271
|
-
if not
|
|
271
|
+
if not F.is_sequence_shape_unknown(x_shape):
|
|
272
272
|
x_axis = axis + x_dim
|
|
273
273
|
else:
|
|
274
274
|
onehot_axis_is_neg = True
|
|
@@ -279,13 +279,13 @@ def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
|
|
|
279
279
|
else:
|
|
280
280
|
dout_expand = expand(dout[1], onehot_axis)
|
|
281
281
|
out_shape = shape_op(out[0])
|
|
282
|
-
if not
|
|
282
|
+
if not F.is_sequence_shape_unknown(out_shape):
|
|
283
283
|
if onehot_axis >= len(out_shape):
|
|
284
284
|
onehot_axis = -1
|
|
285
285
|
type_x = F.dtype(x)
|
|
286
286
|
on_value = F.cast(F.scalar_to_tensor(1.0), type_x)
|
|
287
287
|
off_value = F.cast(F.scalar_to_tensor(0.0), type_x)
|
|
288
|
-
if not
|
|
288
|
+
if not F.is_sequence_value_unknown(x_shape):
|
|
289
289
|
depth = 1
|
|
290
290
|
if x_shape:
|
|
291
291
|
depth = x_shape[axis]
|
|
@@ -308,35 +308,6 @@ def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
|
|
|
308
308
|
return dx
|
|
309
309
|
|
|
310
310
|
|
|
311
|
-
@bprop_getters.register(P.MatMul)
|
|
312
|
-
def bprop_matmul(self):
|
|
313
|
-
"""Grad definition for `MatMul` operation."""
|
|
314
|
-
ta = self.transpose_a
|
|
315
|
-
tb = self.transpose_b
|
|
316
|
-
mul1 = P.MatMul(transpose_a=(ta and tb),
|
|
317
|
-
transpose_b=(ta or (not tb)))
|
|
318
|
-
mul2 = P.MatMul(transpose_a=((not ta) or tb),
|
|
319
|
-
transpose_b=(ta and tb))
|
|
320
|
-
|
|
321
|
-
def bprop(x, w, out, dout):
|
|
322
|
-
conj = P.Conj()
|
|
323
|
-
origin_dtype = x.dtype
|
|
324
|
-
if origin_dtype in (mstype.complex64, mstype.complex128):
|
|
325
|
-
x = conj(x)
|
|
326
|
-
w = conj(w)
|
|
327
|
-
if ta:
|
|
328
|
-
dx = mul1(w, dout)
|
|
329
|
-
else:
|
|
330
|
-
dx = mul1(dout, w)
|
|
331
|
-
if tb:
|
|
332
|
-
dw = mul2(dout, x)
|
|
333
|
-
else:
|
|
334
|
-
dw = mul2(x, dout)
|
|
335
|
-
return dx, dw
|
|
336
|
-
|
|
337
|
-
return bprop
|
|
338
|
-
|
|
339
|
-
|
|
340
311
|
@bprop_getters.register(P.BatchMatMul)
|
|
341
312
|
def bprop_batchmatmul(self):
|
|
342
313
|
"""Grad definition for `BatchMatMul` operation."""
|
|
@@ -361,16 +332,6 @@ def bprop_batchmatmul(self):
|
|
|
361
332
|
return bprop
|
|
362
333
|
|
|
363
334
|
|
|
364
|
-
@bprop_getters.register(P.Add)
|
|
365
|
-
def get_bprop_add(self):
|
|
366
|
-
"""Grad definition for `Add` operation."""
|
|
367
|
-
|
|
368
|
-
def bprop(x, y, out, dout):
|
|
369
|
-
return binop_grad_common(x, y, dout, dout)
|
|
370
|
-
|
|
371
|
-
return bprop
|
|
372
|
-
|
|
373
|
-
|
|
374
335
|
@bprop_getters.register(P.TensorAdd)
|
|
375
336
|
def get_bprop_tensor_add(self):
|
|
376
337
|
"""Grad definition for `Add` operation."""
|
|
@@ -397,35 +358,14 @@ def get_bprop_matrix_inverse(self):
|
|
|
397
358
|
return bprop
|
|
398
359
|
|
|
399
360
|
|
|
400
|
-
@bprop_getters.register(P.Neg)
|
|
401
|
-
def get_bprop_neg(self):
|
|
402
|
-
"""Grad definition for `Neg` operation."""
|
|
403
|
-
neg_grad = P.Neg()
|
|
404
|
-
|
|
405
|
-
def bprop(x, out, dout):
|
|
406
|
-
dx = neg_grad(dout)
|
|
407
|
-
return (dx,)
|
|
408
|
-
|
|
409
|
-
return bprop
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
@bprop_getters.register(P.Sub)
|
|
413
|
-
def get_bprop_sub(self):
|
|
414
|
-
"""Grad definition for `Sub` operation."""
|
|
415
|
-
neg_func = P.Neg()
|
|
416
|
-
|
|
417
|
-
def bprop(x, y, out, dout):
|
|
418
|
-
return binop_grad_common(x, y, dout, neg_func(dout))
|
|
419
|
-
|
|
420
|
-
return bprop
|
|
421
|
-
|
|
422
|
-
|
|
423
361
|
@bprop_getters.register(P.Mul)
|
|
424
362
|
def get_bprop_mul(self):
|
|
425
363
|
"""Grad definition for `Mul` operation."""
|
|
426
364
|
mul_func = P.Mul()
|
|
427
365
|
|
|
428
366
|
def bprop(x, y, out, dout):
|
|
367
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
368
|
+
raise TypeError("For 'Mul', gradient not support for complex type currently.")
|
|
429
369
|
bc_dx = mul_func(y, dout)
|
|
430
370
|
bc_dy = mul_func(x, dout)
|
|
431
371
|
return binop_grad_common(x, y, bc_dx, bc_dy)
|
|
@@ -441,6 +381,8 @@ def get_bprop_real_div(self):
|
|
|
441
381
|
mul_op = P.Mul()
|
|
442
382
|
|
|
443
383
|
def bprop(x, y, out, dout):
|
|
384
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
385
|
+
raise TypeError("For 'RealDiv', gradient not support for complex type currently.")
|
|
444
386
|
bc_x = div_op(dout, y)
|
|
445
387
|
bc_y = neg(mul_op(bc_x, out))
|
|
446
388
|
return binop_grad_common(x, y, bc_x, bc_y)
|
|
@@ -501,7 +443,7 @@ def get_bprop_floor(self):
|
|
|
501
443
|
dtype_ = P.DType()
|
|
502
444
|
|
|
503
445
|
def bprop(x, out, dout):
|
|
504
|
-
if
|
|
446
|
+
if F.is_sequence_value_unknown(shape_(x)):
|
|
505
447
|
bc_x = zeros_like(x)
|
|
506
448
|
else:
|
|
507
449
|
bc_x = fill_(dtype_(x), shape_(x), 0.)
|
|
@@ -518,7 +460,7 @@ def get_bprop_ceil(self):
|
|
|
518
460
|
dtype_ = P.DType()
|
|
519
461
|
|
|
520
462
|
def bprop(x, out, dout):
|
|
521
|
-
if
|
|
463
|
+
if F.is_sequence_value_unknown(shape_(x)):
|
|
522
464
|
bc_x = zeros_like(x)
|
|
523
465
|
else:
|
|
524
466
|
bc_x = fill_(dtype_(x), shape_(x), 0.)
|
|
@@ -537,6 +479,36 @@ def get_bprop_floordiv(self):
|
|
|
537
479
|
return bprop
|
|
538
480
|
|
|
539
481
|
|
|
482
|
+
@bprop_getters.register(P.BitwiseAnd)
|
|
483
|
+
def get_bprop_bitwiseand(self):
|
|
484
|
+
"""Grad definition for `BitwiseAnd` operation."""
|
|
485
|
+
|
|
486
|
+
def bprop(x, y, out, dout):
|
|
487
|
+
return zeros_like(x), zeros_like(y)
|
|
488
|
+
|
|
489
|
+
return bprop
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
@bprop_getters.register(P.BitwiseOr)
|
|
493
|
+
def get_bprop_bitwiseor(self):
|
|
494
|
+
"""Grad definition for `BitwiseOr` operation."""
|
|
495
|
+
|
|
496
|
+
def bprop(x, y, out, dout):
|
|
497
|
+
return zeros_like(x), zeros_like(y)
|
|
498
|
+
|
|
499
|
+
return bprop
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
@bprop_getters.register(P.BitwiseXor)
|
|
503
|
+
def get_bprop_bitwisexor(self):
|
|
504
|
+
"""Grad definition for `BitwiseXor` operation."""
|
|
505
|
+
|
|
506
|
+
def bprop(x, y, out, dout):
|
|
507
|
+
return zeros_like(x), zeros_like(y)
|
|
508
|
+
|
|
509
|
+
return bprop
|
|
510
|
+
|
|
511
|
+
|
|
540
512
|
@bprop_getters.register(P.FloorMod)
|
|
541
513
|
def get_bprop_floormod(self):
|
|
542
514
|
"""Grad definition for `FloorMod` operation."""
|
|
@@ -594,7 +566,7 @@ def get_bprop_square(self):
|
|
|
594
566
|
def bprop(x, out, dout):
|
|
595
567
|
temp = mul_func(dout, x)
|
|
596
568
|
shape_x = shape_op(x)
|
|
597
|
-
if
|
|
569
|
+
if F.is_sequence_value_unknown(shape_x):
|
|
598
570
|
fill_value = dyn_fill(dtype(temp), dyn_shape_op(x), 2.0)
|
|
599
571
|
else:
|
|
600
572
|
fill_value = fill_func(dtype(temp), shape_x, 2.0)
|
|
@@ -644,12 +616,12 @@ def get_bprop_square_sum_all(self):
|
|
|
644
616
|
def bprop(x, y, out, dout):
|
|
645
617
|
temp_x = mul_func(dout[0], x)
|
|
646
618
|
temp_y = mul_func(dout[1], y)
|
|
647
|
-
if
|
|
619
|
+
if F.is_sequence_value_unknown(shape_op(x)):
|
|
648
620
|
dx = mul_func(dyn_fill(dtype(temp_x), dyn_shape_op(x), 2.0), temp_x)
|
|
649
621
|
else:
|
|
650
622
|
dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
|
|
651
623
|
|
|
652
|
-
if
|
|
624
|
+
if F.is_sequence_value_unknown(shape_op(y)):
|
|
653
625
|
dy = mul_func(dyn_fill(dtype(temp_y), dyn_shape_op(y), 2.0), temp_y)
|
|
654
626
|
else:
|
|
655
627
|
dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
|
|
@@ -792,9 +764,11 @@ def get_bprop_pow(self):
|
|
|
792
764
|
ln = P.Log()
|
|
793
765
|
|
|
794
766
|
def bprop(x, power, out, dout):
|
|
767
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
768
|
+
raise TypeError("For 'Pow', gradient not support for complex type currently.")
|
|
795
769
|
bc_dx = power * pow_op(x, power - 1.0) * dout
|
|
796
770
|
shape_x = shape_op(x)
|
|
797
|
-
if
|
|
771
|
+
if F.is_sequence_value_unknown(shape_x):
|
|
798
772
|
x = F.select(x < 0, dyn_fill(F.dtype(x), dyn_shape_op(x), 1), x)
|
|
799
773
|
else:
|
|
800
774
|
x = F.select(x < 0, F.fill(F.dtype(x), F.shape(x), 1), x)
|
|
@@ -888,21 +862,31 @@ def get_bprop_cumsum(self):
|
|
|
888
862
|
return bprop
|
|
889
863
|
|
|
890
864
|
|
|
891
|
-
@
|
|
865
|
+
@_primexpr
|
|
892
866
|
def _split_shape_index(input_shape, axis):
|
|
893
867
|
"""Calculate reduce_prod grad transpose indices and perm shape."""
|
|
894
868
|
rank = len(input_shape)
|
|
895
869
|
if isinstance(axis, int):
|
|
896
870
|
axis = tuple([axis])
|
|
897
871
|
reduction_indices = tuple([(i + rank) % rank for i in axis])
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
872
|
+
other_indices_list = []
|
|
873
|
+
for i in range(rank):
|
|
874
|
+
if i not in reduction_indices and i not in other_indices_list:
|
|
875
|
+
other_indices_list.append(i)
|
|
876
|
+
other_indices = tuple(other_indices_list)
|
|
877
|
+
reduced_list = [1] + [input_shape[i] for i in reduction_indices]
|
|
878
|
+
other_list = [1] + [input_shape[i] for i in other_indices]
|
|
879
|
+
reduced_num = 1
|
|
880
|
+
for i in reduced_list:
|
|
881
|
+
reduced_num = reduced_num * i
|
|
882
|
+
other_num = 1
|
|
883
|
+
for i in other_list:
|
|
884
|
+
other_num = other_num * i
|
|
901
885
|
perm = reduction_indices + other_indices
|
|
902
886
|
return tuple([reduced_num, other_num]), perm
|
|
903
887
|
|
|
904
888
|
|
|
905
|
-
@
|
|
889
|
+
@_primexpr
|
|
906
890
|
def _invert_permutation(perm):
|
|
907
891
|
"""Calculate invert permutation."""
|
|
908
892
|
out = [0] * len(perm)
|
|
@@ -940,9 +924,15 @@ def get_bprop_reduceprod(self):
|
|
|
940
924
|
|
|
941
925
|
def bprop(x, axis, out, dout):
|
|
942
926
|
"""Grad definition for `Product` operation."""
|
|
927
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
928
|
+
raise TypeError("The 'ReduceProd', gradient not support for complex type currently.")
|
|
943
929
|
# Expand dout to full input shape
|
|
944
930
|
input_shape = shape_op(x)
|
|
945
|
-
if
|
|
931
|
+
if input_shape == ():
|
|
932
|
+
dx = _sum_grad(x, axis, dout)
|
|
933
|
+
return dx, zeros_like(axis)
|
|
934
|
+
|
|
935
|
+
if F.is_sequence_value_unknown(input_shape):
|
|
946
936
|
input_shape = dyn_shape_op(x)
|
|
947
937
|
input_shape = P.Cast()(input_shape, ms.int64)
|
|
948
938
|
output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
|
|
@@ -953,14 +943,14 @@ def get_bprop_reduceprod(self):
|
|
|
953
943
|
dout = reshape(dout, output_shape_kept_dims)
|
|
954
944
|
|
|
955
945
|
# Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
|
|
956
|
-
if
|
|
946
|
+
if F.is_sequence_value_unknown(shape_op(x)):
|
|
957
947
|
pack_shape, perm = _split_dyn_shape_index(x, axis)
|
|
958
948
|
else:
|
|
959
949
|
pack_shape, perm = _split_shape_index(shape_op(x), axis)
|
|
960
950
|
|
|
961
951
|
permuted = transpose(x, perm)
|
|
962
952
|
permuted_shape = shape_op(permuted)
|
|
963
|
-
if
|
|
953
|
+
if F.is_sequence_value_unknown(permuted_shape):
|
|
964
954
|
permuted_shape = dyn_shape_op(permuted)
|
|
965
955
|
pack_shape = create_tensor_by_element(pack_shape)
|
|
966
956
|
reshaped = reshape(permuted, pack_shape)
|
|
@@ -972,7 +962,7 @@ def get_bprop_reduceprod(self):
|
|
|
972
962
|
|
|
973
963
|
# Invert the transpose and reshape operations.
|
|
974
964
|
# Make sure to set the statically known shape information through a reshape.
|
|
975
|
-
if
|
|
965
|
+
if F.is_sequence_value_unknown(shape_op(permuted)):
|
|
976
966
|
dout = DynamicBroadcastTo()(dout, input_shape)
|
|
977
967
|
out = transpose(y, dyn_invert_permutation(perm)) * dout
|
|
978
968
|
else:
|
|
@@ -1027,6 +1017,8 @@ def get_bprop_reducemax(self):
|
|
|
1027
1017
|
"""Grad definition for `Max` operation."""
|
|
1028
1018
|
|
|
1029
1019
|
def bprop(x, axis, out, dout):
|
|
1020
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1021
|
+
raise TypeError("The 'ReduceMax', gradient not support for complex type currently.")
|
|
1030
1022
|
dx = _min_or_max_grad(x, axis, out, dout)
|
|
1031
1023
|
return (dx, zeros_like(axis))
|
|
1032
1024
|
|
|
@@ -1052,6 +1044,8 @@ def get_bprop_reducemin(self):
|
|
|
1052
1044
|
"""Grad definition for `ReduceMin` operation."""
|
|
1053
1045
|
|
|
1054
1046
|
def bprop(x, axis, out, dout):
|
|
1047
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1048
|
+
raise TypeError("The 'ReduceMin', gradient not support for complex type currently.")
|
|
1055
1049
|
dx = _min_or_max_grad(x, axis, out, dout)
|
|
1056
1050
|
return (dx, zeros_like(axis))
|
|
1057
1051
|
|
|
@@ -1080,10 +1074,12 @@ def get_bprop_reduce_mean(self):
|
|
|
1080
1074
|
dtype = P.DType()
|
|
1081
1075
|
|
|
1082
1076
|
def bprop(x, axis, out, dout):
|
|
1077
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1078
|
+
raise TypeError("The 'ReduceMean', gradient not support for complex type currently.")
|
|
1083
1079
|
grad = _sum_grad(x, axis, dout)
|
|
1084
1080
|
shape_x = shape_op(x)
|
|
1085
1081
|
shape_out = shape_op(out)
|
|
1086
|
-
if
|
|
1082
|
+
if F.is_sequence_value_unknown(shape_x) or F.is_sequence_value_unknown(shape_out):
|
|
1087
1083
|
shape_x = dyn_shape_op(x)
|
|
1088
1084
|
shape_out = dyn_shape_op(out)
|
|
1089
1085
|
div_shape = reduce_prod(cast(shape_x, mstype.float32), ()) /\
|
|
@@ -1091,7 +1087,7 @@ def get_bprop_reduce_mean(self):
|
|
|
1091
1087
|
dx = div_op(grad, cast(div_shape, dtype(grad)))
|
|
1092
1088
|
else:
|
|
1093
1089
|
div_shape = F.shape_mul(shape_x) / F.shape_mul(shape_out)
|
|
1094
|
-
dx = div_op(grad, cast(F.scalar_to_tensor(div_shape
|
|
1090
|
+
dx = div_op(grad, cast(F.scalar_to_tensor(div_shape), dtype(grad)))
|
|
1095
1091
|
return dx, zeros_like(axis)
|
|
1096
1092
|
|
|
1097
1093
|
return bprop
|
|
@@ -1217,16 +1213,6 @@ def get_bprop_logical_and(self):
|
|
|
1217
1213
|
return bprop
|
|
1218
1214
|
|
|
1219
1215
|
|
|
1220
|
-
@bprop_getters.register(P.LogicalOr)
|
|
1221
|
-
def get_bprop_logical_or(self):
|
|
1222
|
-
"""Grad definition for `LogicalOr` operation."""
|
|
1223
|
-
|
|
1224
|
-
def bprop(x, y, out, dout):
|
|
1225
|
-
return zeros_like(x), zeros_like(y)
|
|
1226
|
-
|
|
1227
|
-
return bprop
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
1216
|
@bprop_getters.register(P.NPUAllocFloatStatus)
|
|
1231
1217
|
def get_bprop_npu_alloc_float_status(self):
|
|
1232
1218
|
"""Grad definition for `NPUAllocFloatStatus` operation."""
|
|
@@ -1424,6 +1410,9 @@ def get_bprop_cosh(self):
|
|
|
1424
1410
|
sinh = P.Sinh()
|
|
1425
1411
|
|
|
1426
1412
|
def bprop(x, out, dout):
|
|
1413
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1414
|
+
raise TypeError("The 'Cosh', gradient not support for complex type currently.")
|
|
1415
|
+
|
|
1427
1416
|
dx = sinh(x) * dout
|
|
1428
1417
|
return (dx,)
|
|
1429
1418
|
|
|
@@ -1454,16 +1443,6 @@ def get_bprop_conj(self):
|
|
|
1454
1443
|
return bprop
|
|
1455
1444
|
|
|
1456
1445
|
|
|
1457
|
-
@bprop_getters.register(P.ScalarCast)
|
|
1458
|
-
def get_bprop_scalar_cast(self):
|
|
1459
|
-
"""Generate bprop for ScalarCast"""
|
|
1460
|
-
|
|
1461
|
-
def bprop(x, t, out, dout):
|
|
1462
|
-
return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)
|
|
1463
|
-
|
|
1464
|
-
return bprop
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
1446
|
@bprop_getters.register(P.AccumulateNV2)
|
|
1468
1447
|
def get_bprop_scalar_accumulatenv2(self):
|
|
1469
1448
|
"""Generate bprop for AccumulateNV2"""
|
|
@@ -1577,6 +1556,9 @@ def get_bprop_tan(self):
|
|
|
1577
1556
|
cos = P.Cos()
|
|
1578
1557
|
|
|
1579
1558
|
def bprop(x, out, dout):
|
|
1559
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1560
|
+
raise TypeError("For 'Tan', gradient not support for complex type currently.")
|
|
1561
|
+
|
|
1580
1562
|
cosx = cos(x)
|
|
1581
1563
|
secx2 = square(reciprocal(cosx))
|
|
1582
1564
|
dx = secx2 * dout
|
|
@@ -1618,6 +1600,9 @@ def get_bprop_atanh(self):
|
|
|
1618
1600
|
div = P.Div()
|
|
1619
1601
|
|
|
1620
1602
|
def bprop(x, out, dout):
|
|
1603
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1604
|
+
raise TypeError("For 'Atanh', gradient not support for complex type currently.")
|
|
1605
|
+
|
|
1621
1606
|
tmp = 1 - power(x, 2)
|
|
1622
1607
|
dx = div(1, tmp) * dout
|
|
1623
1608
|
return (dx,)
|
|
@@ -1657,3 +1642,43 @@ def get_bprop_index_add(self):
|
|
|
1657
1642
|
return dout, zeros_like(indices), gather(dout, indices, _axis)
|
|
1658
1643
|
|
|
1659
1644
|
return bprop
|
|
1645
|
+
|
|
1646
|
+
|
|
1647
|
+
@bprop_getters.register(P.InplaceUpdate)
|
|
1648
|
+
def get_bprop_inplace_update(self):
|
|
1649
|
+
"""Grad definition for `InplaceUpdate` operation."""
|
|
1650
|
+
|
|
1651
|
+
def bprop(x, v, out, dout):
|
|
1652
|
+
return zeros_like(x), zeros_like(v)
|
|
1653
|
+
|
|
1654
|
+
return bprop
|
|
1655
|
+
|
|
1656
|
+
|
|
1657
|
+
@bprop_getters.register(P.InplaceUpdateV2)
|
|
1658
|
+
def get_bprop_inplace_update_v2(self):
|
|
1659
|
+
"""Grad definition for `InplaceUpdateV2` operation."""
|
|
1660
|
+
|
|
1661
|
+
def bprop(x, indices, v, out, dout):
|
|
1662
|
+
return zeros_like(x), zeros_like(indices), zeros_like(v)
|
|
1663
|
+
|
|
1664
|
+
return bprop
|
|
1665
|
+
|
|
1666
|
+
|
|
1667
|
+
@bprop_getters.register(P.InplaceSub)
|
|
1668
|
+
def get_bprop_inplace_sub(self):
|
|
1669
|
+
"""Grad definition for `InplaceSub` operation."""
|
|
1670
|
+
|
|
1671
|
+
def bprop(x, input_v, out, dout):
|
|
1672
|
+
return zeros_like(x), zeros_like(input_v)
|
|
1673
|
+
|
|
1674
|
+
return bprop
|
|
1675
|
+
|
|
1676
|
+
|
|
1677
|
+
@bprop_getters.register(P.InplaceAdd)
|
|
1678
|
+
def get_bprop_inplace_add(self):
|
|
1679
|
+
"""Grad definition for `InplaceAdd` operation."""
|
|
1680
|
+
|
|
1681
|
+
def bprop(x, input_v, out, dout):
|
|
1682
|
+
return zeros_like(x), zeros_like(input_v)
|
|
1683
|
+
|
|
1684
|
+
return bprop
|