mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""grad_sequence_ops"""
|
|
17
|
+
|
|
18
|
+
from mindspore.ops.operations import _sequence_ops as seq
|
|
19
|
+
from mindspore.ops import operations as P
|
|
20
|
+
from mindspore.ops import functional as F
|
|
21
|
+
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
22
|
+
from mindspore.ops._grad.grad_base import bprop_getters
|
|
23
|
+
from mindspore.ops.primitive import Primitive
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
tuple_setitem = Primitive("tuple_setitem")
|
|
27
|
+
list_setitem = Primitive("list_setitem")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@bprop_getters.register(seq.SequenceCount)
def get_bprop_count(self):
    """Build the backward function registered for SequenceCount."""

    def bprop(x, y, out, dout):
        # Neither input receives a gradient: both slots are filled with zeros.
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return (grad_x, grad_y)

    return bprop
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@bprop_getters.register(seq.sequence_len)
def get_bprop_sequence_len(self):
    """Build the backward function registered for sequence_len."""

    def bprop(x, out, dout):
        # The only input gets a zero gradient of its own structure.
        grad_x = zeros_like(x)
        return (grad_x,)

    return bprop
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@bprop_getters.register(seq.SequenceAdd)
def get_bprop_sequence_add(self):
    """Build the backward function registered for SequenceAdd.

    The gradient of each operand is a slice of ``dout``.
    """

    def bprop(x, y, out, dout):
        # SequenceAddOffset yields the start index used for each operand's slice of dout.
        offsets = seq.SequenceAddOffset()(x, y)
        x_len = len(x)
        total_len = x_len + len(y)
        grad_x = seq.SequenceSlice()(dout, offsets[0], x_len, 1)
        grad_y = seq.SequenceSlice()(dout, offsets[1], total_len, 1)
        return (grad_x, grad_y)

    return bprop
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@bprop_getters.register(seq.SequenceSlice)
def get_bprop_slice(self):
    """Build the backward function registered for SequenceSlice."""

    def bprop(x, start, stop, step, out, dout):
        # Only the sliced sequence gets a real gradient; the slice bounds get zeros.
        grad_x = seq.SequenceSliceGrad()(dout, x, start, stop, step)
        grad_start = zeros_like(start)
        grad_stop = zeros_like(stop)
        grad_step = zeros_like(step)
        return (grad_x, grad_start, grad_stop, grad_step)

    return bprop
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@bprop_getters.register(seq.SequenceIndex)
def get_bprop_index(self):
    """Build the backward function registered for SequenceIndex."""

    def bprop(x, y, start, end, out, dout):
        # All four inputs receive zero gradients.
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        grad_start = zeros_like(start)
        grad_end = zeros_like(end)
        return (grad_x, grad_y, grad_start, grad_end)

    return bprop
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@bprop_getters.register(seq.InSequence)
def get_bprop_insequence(self):
    """Build the backward function registered for InSequence."""

    def bprop(x, y, out, dout):
        # x is zeroed with zeros_like, y with the dedicated SequenceZerosLike primitive.
        grad_x = zeros_like(x)
        grad_y = seq.SequenceZerosLike()(y)
        return (grad_x, grad_y)

    return bprop
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@bprop_getters.register("tuple_equal")
@bprop_getters.register("list_equal")
def get_bprop_seq_equal(self):
    """Build the backward function shared by tuple_equal and list_equal."""

    def bprop(x, y, out, dout):
        # Both compared operands receive zero gradients.
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return (grad_x, grad_y)

    return bprop
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@bprop_getters.register("shape_mul")
def get_bprop_shape_mul(self):
    """Generate bprop for shape_mul.

    The input gradient is delegated to the ``ShapeMulGrad`` primitive, which
    is called with the forward input and the incoming gradient.
    """

    def bprop(x, out, dout):
        dx = seq.ShapeMulGrad()(x, dout)
        return (dx,)

    return bprop
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@bprop_getters.register("tuple_setitem")
def get_bprop_tuple_setitem(self):
    """Build the backward function registered for tuple_setitem."""

    def bprop(x, idx, value, out, dout):
        # The index slot always gets the zero of a constant 0.
        index_placeholder = 0
        grad_idx = zeros_like(index_placeholder)
        # The written element takes its gradient straight from dout[idx] ...
        grad_value = dout[idx]
        # ... while the original tuple sees dout with that position zeroed out.
        grad_x = tuple_setitem(dout, idx, zeros_like(value))
        return (grad_x, grad_idx, grad_value)

    return bprop
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@bprop_getters.register("list_setitem")
def get_bprop_lsit_setitem(self):
    """Build the backward function registered for list_setitem."""

    def bprop(x, idx, value, out, dout):
        # The index slot always gets the zero of a constant 0.
        index_placeholder = 0
        grad_idx = zeros_like(index_placeholder)
        # The written element takes its gradient straight from dout[idx] ...
        grad_value = dout[idx]
        # ... while the original list sees dout with that position zeroed out.
        grad_x = list_setitem(dout, idx, zeros_like(value))
        return (grad_x, grad_idx, grad_value)

    return bprop
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@bprop_getters.register(seq.ListAppend)
def get_bprop_list_append(self):
    """Build the backward function registered for ListAppend."""

    def bprop(x, value, out, dout):
        # -1 occupies the argument slot that the ListInsert bprop fills with its index.
        grad_x = seq.ListAppendAndInsertGrad()(dout, -1)
        grad_value = zeros_like(value)
        return (grad_x, grad_value)

    return bprop
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@bprop_getters.register(seq.ListInsert)
def get_bprop_list_insert(self):
    """Build the backward function registered for ListInsert."""

    def bprop(x, idx, value, out, dout):
        # The list gradient comes from ListAppendAndInsertGrad; idx and value get zeros.
        grad_x = seq.ListAppendAndInsertGrad()(dout, idx)
        grad_idx = zeros_like(idx)
        grad_value = zeros_like(value)
        return (grad_x, grad_idx, grad_value)

    return bprop
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@bprop_getters.register(seq.TupleToTensor)
def get_bprop_tuple_to_tensor(self):
    """Build the backward function registered for TupleToTensor."""

    def bprop(x, dtype, out, dout):
        # Cast the incoming gradient to the type reported for x, then convert it back to a tuple.
        source_type = F.typeof(x)
        casted_grad = P.Cast()(dout, source_type)
        grad_x = seq.TensorToTuple()(casted_grad)
        grad_dtype = zeros_like(dtype)
        return (grad_x, grad_dtype)

    return bprop
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@bprop_getters.register(seq.ListToTensor)
def get_bprop_list_to_tensor(self):
    """Build the backward function registered for ListToTensor."""

    def bprop(x, dtype, out, dout):
        # Cast the incoming gradient to the type reported for x, then convert it back to a list.
        source_type = F.typeof(x)
        casted_grad = P.Cast()(dout, source_type)
        grad_x = seq.TensorToList()(casted_grad)
        grad_dtype = zeros_like(dtype)
        return (grad_x, grad_dtype)

    return bprop
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@bprop_getters.register(P.ScalarToTensor)
def get_bprop_scalar_to_tensor(self):
    """Build the backward function registered for ScalarToTensor."""

    def bprop(x, dtype, out, dout):
        # Cast the incoming gradient to the type reported for x, then convert it back to a scalar.
        source_type = F.typeof(x)
        casted_grad = P.Cast()(dout, source_type)
        grad_x = seq.TensorToScalar()(casted_grad)
        grad_dtype = zeros_like(dtype)
        return (grad_x, grad_dtype)

    return bprop
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@bprop_getters.register(seq.TensorToTuple)
def get_bprop_tensor_to_tuple(self):
    """Build the backward function registered for TensorToTuple."""

    def bprop(x, out, dout):
        # The tuple gradient is turned back into a tensor using the type reported for x.
        input_type = F.typeof(x)
        grad_x = seq.TupleToTensor()(dout, input_type)
        return (grad_x,)

    return bprop
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@bprop_getters.register(seq.TensorToList)
def get_bprop_tensor_to_list(self):
    """Build the backward function registered for TensorToList."""

    def bprop(x, out, dout):
        # The list gradient is turned back into a tensor using the type reported for x.
        input_type = F.typeof(x)
        grad_x = seq.ListToTensor()(dout, input_type)
        return (grad_x,)

    return bprop
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@bprop_getters.register(seq.TensorToScalar)
def get_bprop_tensor_to_scalar(self):
    """Build the backward function registered for TensorToScalar."""

    def bprop(x, out, dout):
        # The scalar gradient is turned back into a tensor using the type reported for x.
        input_type = F.typeof(x)
        grad_x = P.ScalarToTensor()(dout, input_type)
        return (grad_x,)

    return bprop
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@bprop_getters.register("tuple_le")
@bprop_getters.register("tuple_lt")
@bprop_getters.register("list_le")
@bprop_getters.register("list_lt")
def get_bprop_less(self):
    """Build the backward function shared by tuple_le, tuple_lt, list_le and list_lt."""

    def bprop(x, y, out, dout):
        # Both compared operands receive zero gradients.
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return (grad_x, grad_y)

    return bprop
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
@bprop_getters.register(seq.SequenceMul)
def get_bprop_mul(self):
    """Build the backward function registered for SequenceMul."""

    def bprop(x, y, out, dout):
        # Start from x itself and overwrite every position with the matching
        # leading element of dout, using the setter that fits x's container type.
        grad_x = x
        if not isinstance(x, tuple):
            for pos in range(len(x)):
                grad_x = list_setitem(grad_x, pos, dout[pos])
        else:
            for pos in range(len(x)):
                grad_x = tuple_setitem(grad_x, pos, dout[pos])
        grad_y = zeros_like(y)
        return (grad_x, grad_y)

    return bprop
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
@bprop_getters.register(seq.SequenceMin)
@bprop_getters.register(seq.SequenceMax)
def get_bprop_max_min(self):
    """Generate bprop for SequenceMin and SequenceMax.

    The incoming gradient is routed to the position of ``out`` inside ``x``;
    every other position receives zero.
    """

    def bprop(x, out, dout):
        # x.index(out) locates the first element equal to the selected value.
        index = x.index(out)
        if isinstance(x, tuple):
            dx = tuple_setitem(zeros_like(x), index, dout)
        else:
            dx = list_setitem(zeros_like(x), index, dout)
        return (dx,)

    return bprop
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
@bprop_getters.register("tuple_greater_than")
@bprop_getters.register("list_greater_than")
@bprop_getters.register("tuple_greater_equal")
@bprop_getters.register("list_greater_equal")
def get_bprop_greater(self):
    """Build the backward function shared by the four "greater" comparisons:
    tuple_greater_than, list_greater_than, tuple_greater_equal and
    list_greater_equal.
    """

    def bprop(x, y, out, dout):
        # Both compared operands receive zero gradients.
        grad_x = zeros_like(x)
        grad_y = zeros_like(y)
        return (grad_x, grad_y)

    return bprop
|
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
|
|
16
16
|
"""bprop primitives"""
|
|
17
|
-
from mindspore.ops._utils.utils import is_shape_unknown
|
|
18
17
|
from mindspore.ops._grad.grad_base import bprops, bprop_getters
|
|
19
18
|
from mindspore.ops.composite.multitype_ops._constexpr_utils import infer_out_shape
|
|
20
19
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
@@ -122,7 +121,7 @@ def get_bprop_sparse_add(self):
|
|
|
122
121
|
dx1, dx2 = sparse_add_grad(dout[1], x1_indices, x2_indices, out[0])
|
|
123
122
|
ret0 = zeros_like(x1_indices)
|
|
124
123
|
shp = shape_op(x1_values)
|
|
125
|
-
if
|
|
124
|
+
if F.is_sequence_value_unknown(shp):
|
|
126
125
|
shp = dyn_shape_op(x1_values)
|
|
127
126
|
dx1_shape = shp
|
|
128
127
|
ret1 = reshape(dx1, dx1_shape)
|
|
@@ -130,7 +129,7 @@ def get_bprop_sparse_add(self):
|
|
|
130
129
|
|
|
131
130
|
ret3 = zeros_like(x2_indices)
|
|
132
131
|
shp = shape_op(x2_values)
|
|
133
|
-
if
|
|
132
|
+
if F.is_sequence_value_unknown(shp):
|
|
134
133
|
shp = dyn_shape_op(x2_values)
|
|
135
134
|
dx2_shape = shp
|
|
136
135
|
ret4 = reshape(dx2, dx2_shape)
|
|
@@ -223,7 +222,7 @@ def get_bprop_csr_mul(self):
|
|
|
223
222
|
to index the dense input.
|
|
224
223
|
"""
|
|
225
224
|
def bprop(indptr, indices, values, shape, dense, out, dout):
|
|
226
|
-
csr_tensor_grad_value = F.csr_mul(F.make_csr_tensor(indptr, indices, dout, shape), dense)
|
|
225
|
+
csr_tensor_grad_value = F.csr_mul(F.make_csr_tensor(indptr, indices, dout, shape), dense).values
|
|
227
226
|
dense_grad_value = F.mul(dout, values)
|
|
228
227
|
dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape)
|
|
229
228
|
if len(dense.shape) == 1 or dense.shape[0] == 1:
|
|
@@ -261,9 +260,11 @@ def get_bprop_csr_div(self):
|
|
|
261
260
|
shape_y = feature_dim + shape[batch_dim_dense_start:]
|
|
262
261
|
reduce_x, reduce_y = F.broadcast_gradient_args(shape_x, shape_y)
|
|
263
262
|
|
|
264
|
-
|
|
263
|
+
csr_tensor_grad = F.csr_div(F.make_csr_tensor(indptr, indices, dout, shape), dense)
|
|
265
264
|
if reduce_x:
|
|
266
|
-
csr_tensor_grad_value = P.ReduceSum(True)(
|
|
265
|
+
csr_tensor_grad_value = P.ReduceSum(True)(csr_tensor_grad.values, reduce_x)
|
|
266
|
+
else:
|
|
267
|
+
csr_tensor_grad_value = csr_tensor_grad.values
|
|
267
268
|
dense_grad_value = F.neg_tensor(F.mul(out, csr_tensor_grad_value))
|
|
268
269
|
dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape)
|
|
269
270
|
if len(dense.shape) == 1 or dense.shape[0] == 1:
|
|
@@ -24,5 +24,6 @@ from mindspore.ops._grad_experimental import grad_math_ops
|
|
|
24
24
|
from mindspore.ops._grad_experimental import grad_linalg_ops
|
|
25
25
|
from mindspore.ops._grad_experimental import grad_sparse
|
|
26
26
|
from mindspore.ops._grad_experimental import grad_sparse_ops
|
|
27
|
+
from mindspore.ops._grad_experimental import grad_scalar_ops
|
|
27
28
|
|
|
28
29
|
__all__ = ['get_bprop_fn']
|
|
@@ -33,6 +33,7 @@ from mindspore.ops.operations.array_ops import Mvlgamma
|
|
|
33
33
|
from mindspore.ops.operations.array_ops import Triu
|
|
34
34
|
from mindspore.ops.operations.array_ops import IdentityN
|
|
35
35
|
from mindspore.ops.operations.array_ops import IndexFill
|
|
36
|
+
from mindspore.ops.operations.array_ops import IndexPut
|
|
36
37
|
from mindspore.ops.operations.array_ops import CheckNumerics
|
|
37
38
|
from mindspore.ops.operations.array_ops import ConjugateTranspose
|
|
38
39
|
from mindspore.ops.operations.array_ops import SegmentMax
|
|
@@ -47,13 +48,48 @@ from mindspore.ops.operations.array_ops import Im2Col
|
|
|
47
48
|
from mindspore.ops.operations.array_ops import Col2Im
|
|
48
49
|
from mindspore.ops.operations.array_ops import StridedSliceV2
|
|
49
50
|
from mindspore.ops.operations.array_ops import MaskedScatter
|
|
51
|
+
from mindspore.ops.operations.array_ops import MaskedSelect
|
|
52
|
+
from mindspore.ops.operations.array_ops import CountNonZero
|
|
50
53
|
from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
|
|
51
54
|
from mindspore.ops.operations.random_ops import LogNormalReverse
|
|
55
|
+
from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal
|
|
52
56
|
from mindspore.ops.operations import _inner_ops as inner
|
|
53
57
|
from mindspore.ops import functional as F
|
|
54
58
|
from mindspore.ops import operations as P
|
|
55
|
-
from mindspore.ops._utils.utils import is_shape_unknown
|
|
56
59
|
from mindspore.ops.operations import _grad_ops as G
|
|
60
|
+
from mindspore import context
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@constexpr
def _raise_value_error(*info):
    """Raise a ValueError whose message is every item of ``info`` formatted and concatenated."""
    message = "".join(f"{part}" for part in info)
    raise ValueError(message)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@bprop_getters.register(P.FillV2)
def get_bprop_fill_v2(self):
    """Build the backward function registered for FillV2.

    The gradient of ``value`` is the sum of ``dout`` over all of its axes;
    ``shape`` receives a zero gradient.
    """
    reduce_sum = P.ReduceSum()
    cast_fn = P.Cast()
    tensor_shape = P.TensorShape()

    def bprop(shape, value, out, dout):
        grad_dtype = F.dtype(dout)
        # Gradients of these dtypes are summed in float32 and cast back afterwards.
        summed_in_float32 = [
            mstype.int8, mstype.int16, mstype.int32, mstype.int64,
            mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
            mstype.float16, mstype.float64
        ]
        if grad_dtype in summed_in_float32:
            dout = cast_fn(dout, mstype.float32)
        grad_shape = tensor_shape(dout)
        # Reduce over every axis of dout.
        all_axes = tuple([dim for dim in range(len(grad_shape))])
        value_grad = reduce_sum(dout, all_axes)
        return zeros_like(shape), cast_fn(value_grad, grad_dtype)

    return bprop
|
|
57
93
|
|
|
58
94
|
|
|
59
95
|
@bprop_getters.register(StridedSliceV2)
|
|
@@ -69,7 +105,7 @@ def get_bprop_strided_slice_v2(self):
|
|
|
69
105
|
|
|
70
106
|
def bprop(x, begin, end, strides, out, dout):
|
|
71
107
|
x_shape = shape_op(x)
|
|
72
|
-
if
|
|
108
|
+
if F.is_sequence_value_unknown(x_shape):
|
|
73
109
|
x_shape = dyn_shape_op(x)
|
|
74
110
|
dx = input_grad(x_shape, begin, end, strides, dout)
|
|
75
111
|
dx_all = (dx, zeros_like(begin), zeros_like(end), zeros_like(strides))
|
|
@@ -114,7 +150,12 @@ def get_bprop_masked_select(self):
|
|
|
114
150
|
dinput = mul_op(dout, (1 - mask))
|
|
115
151
|
dvalue = mul_op(dout, mask)
|
|
116
152
|
dinput, dvalue = binop_grad_common(input_data, mask, dinput, dvalue)
|
|
117
|
-
|
|
153
|
+
# for dynamic rank, reduce axis should be calc
|
|
154
|
+
if F.is_sequence_shape_unknown(P.Shape()(dvalue)):
|
|
155
|
+
axis = P.Range()(Tensor(0), dyn_rank(dvalue), Tensor(1))
|
|
156
|
+
dvalue = sum_op(dvalue, axis)
|
|
157
|
+
else:
|
|
158
|
+
dvalue = sum_op(dvalue)
|
|
118
159
|
dinput = F.cast(dinput, F.dtype(input_data))
|
|
119
160
|
if is_instance_op(value, mstype.number):
|
|
120
161
|
dvalue = 0
|
|
@@ -163,6 +204,16 @@ def get_bprop_masked_scatter(self):
|
|
|
163
204
|
return bprop
|
|
164
205
|
|
|
165
206
|
|
|
207
|
+
@bprop_getters.register(CountNonZero)
def get_bprop_countnonzero(self):
    """Build the backward function registered for CountNonZero."""

    def bprop(x, out, dout):
        # The input receives a zero gradient.
        grad_x = zeros_like(x)
        return (grad_x,)

    return bprop
|
|
215
|
+
|
|
216
|
+
|
|
166
217
|
@bprop_getters.register(Mvlgamma)
|
|
167
218
|
def get_bprop_mvlgamma(self):
|
|
168
219
|
"""Grad definition for Mvlgamma"""
|
|
@@ -210,7 +261,7 @@ def get_bprop_index_fill(self):
|
|
|
210
261
|
def bprop(x, dim, indices, value, out, dout):
|
|
211
262
|
zero_value = zeros_like(value)
|
|
212
263
|
x_grad = index_fill(dout, dim, indices, zero_value)
|
|
213
|
-
if
|
|
264
|
+
if F.is_sequence_value_unknown(shape(x)):
|
|
214
265
|
if dyn_rank(x) == 0:
|
|
215
266
|
value_grad = dout
|
|
216
267
|
else:
|
|
@@ -226,6 +277,43 @@ def get_bprop_index_fill(self):
|
|
|
226
277
|
return bprop
|
|
227
278
|
|
|
228
279
|
|
|
280
|
+
@bprop_getters.register(IndexPut)
def get_bprop_index_put(self):
    """Build the backward function registered for IndexPut.

    The gradient of the written values is gathered from ``dout`` at the given
    indices. When ``accumulate`` is 0, the gradient of the destination tensor
    has those positions overwritten with zeros.
    """
    gather_nd_op = P.GatherNd()
    stack_op = P.Stack()
    tile_op = P.Tile()
    select_op = MaskedSelect()
    scatter_op = MaskedScatter()
    accumulate = self.accumulate
    index_put_op = IndexPut(accumulate=accumulate)
    on_ascend = context.get_context("device_target") == 'Ascend'

    # GatherNd on Ascend does not support negative indices, so they are
    # shifted into the non-negative range by adding the dimension size.
    def convert_idx_positive(indices_i, x_shape_i):
        negative_mask = indices_i < 0
        wrapped = select_op(indices_i + x_shape_i, negative_mask)
        return scatter_op(indices_i, negative_mask, wrapped)

    def bprop(x1, x2, indices, out, dout):
        # Index tensors of length 1 are tiled up to the longest index tensor.
        longest = max(idx.shape[0] for idx in indices)
        expanded = [tile_op(idx, (longest,)) if idx.shape[0] == 1 else idx for idx in indices]
        if on_ascend:
            expanded = [convert_idx_positive(expanded[i], x1.shape[i]) for i in range(len(expanded))]
        # Stack the per-dimension indices and transpose them for GatherNd.
        nd_indices = stack_op(expanded).T
        values_grad = gather_nd_op(dout, nd_indices)
        if x2.shape[0] == 1:
            values_grad = values_grad.sum().reshape(1)
        if values_grad.shape != x2.shape and len(indices) < len(x1.shape):
            _, values_grad = binop_grad_common(x1, x2, dout, values_grad)
        if accumulate == 0:
            dout = index_put_op(dout, zeros_like(x2), indices)
        return dout, values_grad, [zeros_like(item) for item in indices]

    return bprop
|
|
315
|
+
|
|
316
|
+
|
|
229
317
|
@bprop_getters.register(P.TensorScatterSub)
|
|
230
318
|
def get_bprop_tensor_scatter_sub(self):
|
|
231
319
|
"""Generate bprop for TensorScatterSub"""
|
|
@@ -280,7 +368,7 @@ def get_bprop_matrix_diag_part_v3(self):
|
|
|
280
368
|
|
|
281
369
|
def bprop(x, k, padding_value, out, dout):
|
|
282
370
|
shape_this = P.Shape()(x)[-2:]
|
|
283
|
-
if not
|
|
371
|
+
if not F.is_sequence_value_unknown(shape_this):
|
|
284
372
|
row = shape_this[0]
|
|
285
373
|
col = shape_this[1]
|
|
286
374
|
result = (matrix_diag_v3(dout, k, Tensor(row, dtype=mstype.int32), Tensor(col, dtype=mstype.int32),
|
|
@@ -304,7 +392,7 @@ def get_bprop_matrix_set_diag_v3(self):
|
|
|
304
392
|
diagonal_cal = matrix_diag_part_v3(dout, k, zeros((), dout.dtype))
|
|
305
393
|
|
|
306
394
|
diagonal_shape = P.Shape()(diagonal)
|
|
307
|
-
if
|
|
395
|
+
if F.is_sequence_value_unknown(diagonal_shape):
|
|
308
396
|
diagonal = F.cast(diagonal, dout.dtype)
|
|
309
397
|
x_cal = matrix_set_diag_v3(dout, zeros_like(diagonal), k)
|
|
310
398
|
else:
|
|
@@ -327,7 +415,7 @@ def tensor_scatter_possible_replacement(x, indices, updates, out, dout):
|
|
|
327
415
|
possibly_updated = gather_nd(out, indices)
|
|
328
416
|
out_indicators = F.cast(equal(updates, possibly_updated), mstype.int32)
|
|
329
417
|
input_shape = shape(x)
|
|
330
|
-
if
|
|
418
|
+
if F.is_sequence_value_unknown(input_shape):
|
|
331
419
|
input_shape = dyn_shape_op(x)
|
|
332
420
|
|
|
333
421
|
scattered_out_indicators = scatter_nd(indices, out_indicators, input_shape)
|
|
@@ -347,6 +435,15 @@ def get_bprop_log_normal_reverse(self):
|
|
|
347
435
|
return bprop
|
|
348
436
|
|
|
349
437
|
|
|
438
|
+
@bprop_getters.register(ParameterizedTruncatedNormal)
def get_bprop_parameterized_truncated_normal(self):
    """Build the backward function registered for `ParameterizedTruncatedNormal`."""

    def bprop(shape, mean, stdevs, min_val, max_val, out, dout):
        # Every input receives a zero gradient.
        grad_shape = zeros_like(shape)
        grad_mean = zeros_like(mean)
        grad_stdevs = zeros_like(stdevs)
        grad_min = zeros_like(min_val)
        grad_max = zeros_like(max_val)
        return (grad_shape, grad_mean, grad_stdevs, grad_min, grad_max)

    return bprop
|
|
445
|
+
|
|
446
|
+
|
|
350
447
|
@bprop_getters.register(P.TensorScatterMax)
|
|
351
448
|
def get_bprop_tensor_scatter_max(self):
|
|
352
449
|
"""Generate bprop for TensorScatterMax"""
|
|
@@ -446,13 +543,13 @@ def get_bprop_resize_nearest_neighbor_v2(self):
|
|
|
446
543
|
|
|
447
544
|
def bprop(x, size, output, dout):
|
|
448
545
|
x_shape = P.Shape()(x)
|
|
449
|
-
if
|
|
546
|
+
if F.is_sequence_value_unknown(x_shape):
|
|
450
547
|
x_shape = P.TensorShape()(x)
|
|
451
548
|
grad_in_size = x_shape[1:3]
|
|
452
549
|
if data_format == 'NCHW':
|
|
453
550
|
grad_in_size = x_shape[2:4]
|
|
454
551
|
|
|
455
|
-
if
|
|
552
|
+
if F.is_sequence_value_unknown(P.Shape()(x)):
|
|
456
553
|
dx = grad_op(dout, grad_in_size)
|
|
457
554
|
return dx, zeros_like(grad_in_size)
|
|
458
555
|
|
|
@@ -469,7 +566,7 @@ def get_bprop_col2im(self):
|
|
|
469
566
|
dilations = self.dilation
|
|
470
567
|
strides = self.stride
|
|
471
568
|
pads = self.padding
|
|
472
|
-
im2col = Im2Col(ksizes=ksizes, dilations=dilations, strides=strides,
|
|
569
|
+
im2col = Im2Col(ksizes=ksizes, dilations=dilations, strides=strides, pads=pads)
|
|
473
570
|
|
|
474
571
|
def bprop(x, output_size, out, dout):
|
|
475
572
|
dx = im2col(dout)
|
|
@@ -478,6 +575,36 @@ def get_bprop_col2im(self):
|
|
|
478
575
|
return bprop
|
|
479
576
|
|
|
480
577
|
|
|
578
|
+
@bprop_getters.register(Im2Col)
def get_bprop_im2col(self):
    """
    Build the backward function registered for Im2Col.

    Im2Col corresponds to torch's Unfold operator. Unfold has no
    `padding_mode` attribute, and its implementation corresponds to the
    mindspore implementation with `padding_mode=CALCULATED`, so the bprop
    of Im2Col currently only supports the CALCULATED mode.
    """
    shape_op = P.TensorShape()
    # Col2Im is given only the first and the last entry of the forward pads.
    col2im = Col2Im(kernel_size=self.ksizes,
                    dilation=self.dilations,
                    stride=self.strides,
                    padding=(self.pads[0], self.pads[-1]))

    def bprop(x, out, dout):
        # Everything after the first two dimensions of x is passed to Col2Im as the output size.
        output_size = shape_op(x)[2:]
        grad_x = col2im(dout, output_size)
        return (grad_x,)

    return bprop
|
|
606
|
+
|
|
607
|
+
|
|
481
608
|
@bprop_getters.register(P.ExtractVolumePatches)
|
|
482
609
|
def get_bprop_extract_volume_patches(self):
|
|
483
610
|
"""Generate bprop for ExtractVolumePatches"""
|
|
@@ -538,7 +665,7 @@ def get_bprop_extract_volume_patches(self):
|
|
|
538
665
|
def bprop(x, out, dout):
|
|
539
666
|
x_shape = P.Shape()(x)
|
|
540
667
|
out_shape = P.Shape()(out)
|
|
541
|
-
if
|
|
668
|
+
if F.is_sequence_value_unknown(x_shape) or F.is_sequence_value_unknown(out_shape):
|
|
542
669
|
return _dyn_extract_volume_patches(x, out, dout)
|
|
543
670
|
x_n, x_c, x_d, x_h, x_w = x_shape
|
|
544
671
|
x_indices_num = 1 + x_d * x_h * x_w
|
|
@@ -609,6 +736,7 @@ def get_bprop_affinegrid(self):
|
|
|
609
736
|
"""Generate bprop for AffineGrid"""
|
|
610
737
|
|
|
611
738
|
align_corners = self.align_corners
|
|
739
|
+
input_grad = G.AffineGridGrad(align_corners)
|
|
612
740
|
ones = P.Ones()
|
|
613
741
|
transpose = P.Transpose()
|
|
614
742
|
concat = P.Concat(1)
|
|
@@ -824,12 +952,19 @@ def get_bprop_affinegrid(self):
|
|
|
824
952
|
dtheta = transpose(dtheta, perm2)
|
|
825
953
|
return dtheta, tre
|
|
826
954
|
|
|
827
|
-
def
|
|
955
|
+
def bprop_gpu(theta, output_size, out, dout):
|
|
828
956
|
is_tensor, _ = convert_to_tensor(output_size)
|
|
829
957
|
if is_tensor:
|
|
830
958
|
return dyn_bprop(theta, output_size, out, dout)
|
|
831
959
|
return static_bprop(theta, output_size, out, dout)
|
|
832
960
|
|
|
961
|
+
def bprop(theta, output_size, out, dout):
|
|
962
|
+
dx = input_grad(dout, output_size)
|
|
963
|
+
return dx, zeros_like(output_size)
|
|
964
|
+
|
|
965
|
+
if context.get_context('device_target') == "GPU":
|
|
966
|
+
return bprop_gpu
|
|
967
|
+
|
|
833
968
|
return bprop
|
|
834
969
|
|
|
835
970
|
|
|
@@ -907,7 +1042,7 @@ def get_bprop_expand(self):
|
|
|
907
1042
|
|
|
908
1043
|
def bprop(x, shape, out, dout):
|
|
909
1044
|
reduce_dims = []
|
|
910
|
-
dshape = zeroslike(
|
|
1045
|
+
dshape = zeroslike(shape)
|
|
911
1046
|
dx_shape = dout.shape
|
|
912
1047
|
if dx_shape is None:
|
|
913
1048
|
return dout.sum(), dshape
|
|
@@ -947,12 +1082,12 @@ def get_bprop_segment_mean(self):
|
|
|
947
1082
|
dout_type = F.dtype(dout)
|
|
948
1083
|
|
|
949
1084
|
ones_shape = shape(segment_ids)
|
|
950
|
-
if
|
|
1085
|
+
if F.is_sequence_value_unknown(ones_shape):
|
|
951
1086
|
ones_shape = dyn_shape(segment_ids)
|
|
952
1087
|
|
|
953
1088
|
ones = ()
|
|
954
1089
|
inputx_shape = shape(input_x)
|
|
955
|
-
if
|
|
1090
|
+
if F.is_sequence_value_unknown(inputx_shape):
|
|
956
1091
|
input_rank = dyn_rank(input_x)
|
|
957
1092
|
if input_rank > cast(1, mstype.float32):
|
|
958
1093
|
ones_shape = concat([ones_shape, dyn_ones(expand_dims(input_rank - 1, 0), mstype.int64)])
|