mindspore-2.0.0rc1-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -1,1684 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""Define the grad rules of math related operations."""
|
|
17
|
-
|
|
18
|
-
import numpy as np
|
|
19
|
-
import mindspore as ms
|
|
20
|
-
from mindspore import nn
|
|
21
|
-
from mindspore.common import Tensor
|
|
22
|
-
from mindspore.common import dtype as mstype
|
|
23
|
-
from mindspore.ops import functional as F
|
|
24
|
-
from mindspore.ops import operations as P
|
|
25
|
-
from mindspore.ops.operations import _grad_ops as G
|
|
26
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
27
|
-
from mindspore.ops.functional import broadcast_gradient_args, reduced_shape, tuple_div
|
|
28
|
-
from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_invert_permutation
|
|
29
|
-
from mindspore.ops._grad.grad_base import convert_to_tensor
|
|
30
|
-
from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
|
|
31
|
-
from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
|
|
32
|
-
from mindspore.ops.primitive import _primexpr
|
|
33
|
-
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
34
|
-
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass, DynamicBroadcastTo
|
|
35
|
-
from mindspore.ops.operations import array_ops as A
|
|
36
|
-
|
|
37
|
-
shape_op = P.Shape()
|
|
38
|
-
dyn_shape_op = P.TensorShape()
|
|
39
|
-
reduce_prod = P.ReduceProd()
|
|
40
|
-
reduce_sum = P.ReduceSum()
|
|
41
|
-
reshape = P.Reshape()
|
|
42
|
-
tile = P.Tile()
|
|
43
|
-
is_sub_class = IsSubClass()
|
|
44
|
-
to_array = P.TupleToArray()
|
|
45
|
-
real_div = P.RealDiv()
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def dyn_binop_grad_common(x, y, dx, dy):
|
|
49
|
-
"""
|
|
50
|
-
Common grad definition for binary operations when the input is dynamic shape.
|
|
51
|
-
|
|
52
|
-
The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
|
|
53
|
-
"""
|
|
54
|
-
shape_of_x = dyn_shape_op(x)
|
|
55
|
-
shape_of_y = dyn_shape_op(y)
|
|
56
|
-
rx, ry = DynamicBroadcastGradientArgs()(shape_of_x, shape_of_y)
|
|
57
|
-
dx_origin_dtype = dx.dtype
|
|
58
|
-
if dx_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
|
|
59
|
-
dx = F.cast(dx, mstype.float32)
|
|
60
|
-
dx = sum_grad_reduce_axis(dx, rx)
|
|
61
|
-
dx = F.cast(dx, dx_origin_dtype)
|
|
62
|
-
else:
|
|
63
|
-
dx = sum_grad_reduce_axis(dx, rx)
|
|
64
|
-
dy_origin_dtype = dy.dtype
|
|
65
|
-
if dy_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
|
|
66
|
-
dy = F.cast(dy, mstype.float32)
|
|
67
|
-
dy = sum_grad_reduce_axis(dy, ry)
|
|
68
|
-
dy = F.cast(dy, dy_origin_dtype)
|
|
69
|
-
else:
|
|
70
|
-
dy = sum_grad_reduce_axis(dy, ry)
|
|
71
|
-
reduce_dx = reshape(dx, shape_of_x)
|
|
72
|
-
reduce_dy = reshape(dy, shape_of_y)
|
|
73
|
-
return reduce_dx, reduce_dy
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def dyn_binop_grad_common_with_shift(x, y, dx, dy, shift):
|
|
77
|
-
"""
|
|
78
|
-
Common grad definition for binary operations with shift when the input is dynamic shape.
|
|
79
|
-
|
|
80
|
-
The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
|
|
81
|
-
"""
|
|
82
|
-
shape_of_x = dyn_shape_op(x)
|
|
83
|
-
shape_of_y = dyn_shape_op(y)
|
|
84
|
-
broadcast_shape_of_x = shape_of_x[:-shift]
|
|
85
|
-
broadcast_shape_of_y = shape_of_y[:-shift]
|
|
86
|
-
rx, ry = DynamicBroadcastGradientArgs()(broadcast_shape_of_x, broadcast_shape_of_y)
|
|
87
|
-
dx = sum_grad_reduce_axis(dx, rx)
|
|
88
|
-
dy = sum_grad_reduce_axis(dy, ry)
|
|
89
|
-
reduce_dx = reshape(dx, shape_of_x)
|
|
90
|
-
reduce_dy = reshape(dy, shape_of_y)
|
|
91
|
-
return reduce_dx, reduce_dy
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def _reduce_sum_with_cast(dx, axis):
|
|
95
|
-
dx_origin_dtype = dx.dtype
|
|
96
|
-
# Currently, for Ascend and GPU, the reduce_sum's input does not support int16, int32 and int64.
|
|
97
|
-
if dx_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
|
|
98
|
-
dx = F.cast(dx, mstype.float32)
|
|
99
|
-
dx = reduce_sum(dx, axis)
|
|
100
|
-
return F.cast(dx, dx_origin_dtype)
|
|
101
|
-
return reduce_sum(dx, axis)
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def binop_grad_common(x, y, dx, dy):
|
|
105
|
-
"""
|
|
106
|
-
Common grad definition for binary operations.
|
|
107
|
-
|
|
108
|
-
The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
|
|
109
|
-
"""
|
|
110
|
-
shape_of_x = shape_op(x)
|
|
111
|
-
shape_of_y = shape_op(y)
|
|
112
|
-
# if input shape is the same as dout shape, do not need to reduce
|
|
113
|
-
reduce_dx = dx
|
|
114
|
-
reduce_dy = dy
|
|
115
|
-
if not (F.is_sequence_value_unknown(shape_of_x) or F.is_sequence_value_unknown(shape_of_y)):
|
|
116
|
-
rx = broadcast_gradient_args(shape_of_x, shape_of_y)
|
|
117
|
-
if rx[0]:
|
|
118
|
-
# if dx is scalar whose shape is (), do not need reduce
|
|
119
|
-
if shape_op(dx):
|
|
120
|
-
dx = _reduce_sum_with_cast(dx, rx[0])
|
|
121
|
-
reduce_dx = reshape(dx, shape_of_x)
|
|
122
|
-
if rx[1]:
|
|
123
|
-
# if dy is scalar whose shape is (), do not need reduce
|
|
124
|
-
if shape_op(dy):
|
|
125
|
-
dy = _reduce_sum_with_cast(dy, rx[1])
|
|
126
|
-
reduce_dy = reshape(dy, shape_of_y)
|
|
127
|
-
return reduce_dx, reduce_dy
|
|
128
|
-
|
|
129
|
-
if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
|
|
130
|
-
# x or y is scalar
|
|
131
|
-
if not isinstance(shape_of_x, tuple):
|
|
132
|
-
reduce_dx = _reduce_sum_with_cast(dx, ())
|
|
133
|
-
if not isinstance(shape_of_y, tuple):
|
|
134
|
-
reduce_dy = _reduce_sum_with_cast(dy, ())
|
|
135
|
-
return reduce_dx, reduce_dy
|
|
136
|
-
|
|
137
|
-
return dyn_binop_grad_common(x, y, dx, dy)
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
def binop_grad_common_with_shift(x, y, dx, dy, shift):
|
|
141
|
-
"""
|
|
142
|
-
Common grad definition for binary operations with shift.
|
|
143
|
-
|
|
144
|
-
The function is usually used in backprop op to reduce additional dimensions created by broadcasting.
|
|
145
|
-
"""
|
|
146
|
-
shape_of_x = shape_op(x)
|
|
147
|
-
shape_of_y = shape_op(y)
|
|
148
|
-
broadcast_shape_of_x = shape_of_x[:-shift]
|
|
149
|
-
broadcast_shape_of_y = shape_of_y[:-shift]
|
|
150
|
-
# if input shape is the same as dout shape, do not need to reduce
|
|
151
|
-
reduce_dx = dx
|
|
152
|
-
reduce_dy = dy
|
|
153
|
-
if not (F.is_sequence_value_unknown(broadcast_shape_of_x) or F.is_sequence_value_unknown(broadcast_shape_of_y)):
|
|
154
|
-
rx = broadcast_gradient_args(broadcast_shape_of_x, broadcast_shape_of_y)
|
|
155
|
-
if rx[0]:
|
|
156
|
-
# if dx is scalar whose shape is (), do not need reduce
|
|
157
|
-
if shape_op(dx):
|
|
158
|
-
dx = _reduce_sum_with_cast(dx, rx[0])
|
|
159
|
-
reduce_dx = reshape(dx, shape_of_x)
|
|
160
|
-
if rx[1]:
|
|
161
|
-
# if dy is scalar whose shape is (), do not need reduce
|
|
162
|
-
if shape_op(dy):
|
|
163
|
-
dy = _reduce_sum_with_cast(dy, rx[1])
|
|
164
|
-
reduce_dy = reshape(dy, shape_of_y)
|
|
165
|
-
return reduce_dx, reduce_dy
|
|
166
|
-
|
|
167
|
-
if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
|
|
168
|
-
# x or y is scalar
|
|
169
|
-
if not isinstance(shape_of_x, tuple):
|
|
170
|
-
reduce_dx = _reduce_sum_with_cast(dx, ())
|
|
171
|
-
if not isinstance(shape_of_y, tuple):
|
|
172
|
-
reduce_dy = _reduce_sum_with_cast(dy, ())
|
|
173
|
-
return reduce_dx, reduce_dy
|
|
174
|
-
|
|
175
|
-
return dyn_binop_grad_common_with_shift(x, y, dx, dy, shift)
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def _dyn_reduced_shape(input_shape, axis, x):
    """Dynamic reduce shape"""
    input_shape = P.Cast()(input_shape, ms.int32)
    if x is not None and not F.is_sequence_shape_unknown(shape_op(x)):
        input_rank = len(shape_op(x))
    else:
        input_rank = dyn_rank(x)
    input_rank = P.Cast()(input_rank, ms.int32)

    if (isinstance(axis, tuple) and axis == ()) or (isinstance(axis, list) and axis == []):
        res_shape = P.ExpandDims()(input_rank, 0)
        return dyn_ones(res_shape, res_shape.dtype)

    if isinstance(axis, int):
        axis = (axis,)

    real_axis = axis
    if not isinstance(axis, Tensor):
        real_axis = Tensor(axis, ms.int32)

    real_axis = (real_axis + input_rank) % input_rank
    if real_axis.ndim == 0:
        real_axis = P.ExpandDims()(real_axis, 0)
    expanded_axis = P.ExpandDims()(real_axis, 1)
    expanded_axis = P.Cast()(expanded_axis, ms.int32)
    update = P.Cast()(P.OnesLike()(real_axis), ms.float32)
    input_shape = P.Cast()(input_shape, ms.float32)
    return P.TensorScatterUpdate()(input_shape, expanded_axis, update)


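# Editorial aside: the helper above builds the "keep_dims" shape at run time
# by scattering 1 into every reduced axis of the shape vector. The same idea
# in a minimal NumPy sketch (function name hypothetical, for illustration):
def _example_reduced_shape_kept_dims():
    import numpy as np
    input_shape = np.array([4, 5, 6])
    axis = np.array([1])       # axes being reduced
    kept = input_shape.copy()
    kept[axis] = 1             # scatter 1 into each reduced axis
    assert tuple(kept) == (4, 1, 6)
    return kept

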
def _sum_grad(x, axis, dout):
    """Grad definition for `Sum` operation."""
    input_shape = shape_op(x)
    is_mutable, axis = convert_to_tensor(axis)
    if F.is_sequence_value_unknown(input_shape) or is_mutable:
        input_shape = dyn_shape_op(x)
        output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
        output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int32)
        grad = reshape(dout, output_shape_kept_dims)
        return DynamicBroadcastTo()(grad, input_shape)

    output_shape_kept_dims = reduced_shape(input_shape, axis)
    tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
    grad = reshape(dout, output_shape_kept_dims)
    return tile(grad, tile_scaling)


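# Editorial aside: the gradient of a sum is 1 for every summed element, so
# `_sum_grad` just reshapes `dout` to the keep-dims shape and broadcasts (or
# tiles) it back to the input shape. A minimal NumPy sketch, with a
# hypothetical function name, of the static-shape branch:
def _example_sum_grad():
    import numpy as np
    x = np.arange(12.0).reshape(3, 4)
    dout = np.ones((3,))                 # grad of reduce_sum over axis 1
    grad = dout.reshape(3, 1)            # reshape to the kept-dims shape
    dx = np.broadcast_to(grad, x.shape)  # broadcast/tile back to x's shape
    assert dx.shape == x.shape           # all ones: d(sum)/dx_ij = 1
    return dx

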
def _min_or_max_grad(x, axis, out, dout):
    """Grad definition for `Min` and `Max` operations."""
    input_shape = shape_op(x)
    output_shape_kept_dims = ()
    if F.is_sequence_value_unknown(input_shape):
        input_shape = dyn_shape_op(x)
        output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
        output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int32)
    else:
        output_shape_kept_dims = reduced_shape(input_shape, axis)

    y = reshape(out, output_shape_kept_dims)
    grad = reshape(dout, output_shape_kept_dims)
    indicators = F.cast(F.equal(y, x), F.dtype(grad))
    min_num = F.cast(F.scalar_to_tensor(1e-24), F.dtype(grad))
    num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
    return indicators / num_selected * grad


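# Editorial aside: `_min_or_max_grad` routes the gradient only to positions
# equal to the reduced extremum and splits it evenly among ties; the 1e-24
# term keeps the division well defined. A minimal NumPy sketch (hypothetical
# function name) for a 1-D max:
def _example_max_grad_with_ties():
    import numpy as np
    x = np.array([1.0, 3.0, 3.0, 2.0])
    out, dout = x.max(), 1.0
    indicators = (x == out).astype(x.dtype)
    num_selected = indicators.sum() + 1e-24
    dx = indicators / num_selected * dout  # ties share it: [0, 0.5, 0.5, 0]
    return dx

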
def _onehot_with_neg_axis(axis, indices, depth, on_value_dtype):
    """onehot support tensor axis"""
    depth_range = P.Range()(F.cast(0, depth.dtype), depth, F.cast(1, depth.dtype))
    indices_expand = P.ExpandDims()(indices, axis)
    indices_expand_rank = dyn_rank_1d(indices_expand)
    broad_shape = dyn_ones(indices_expand_rank, mstype.int64)
    # It should use int64 dtype, but the TensorScatterUpdate op does not support the int64
    # dtype on Ascend device, so the float32 dtype is used here.
    update_dtype = mstype.float32
    broad_shape = dyn_ones(indices_expand_rank, update_dtype)
    broad_shape[axis] = F.cast(depth, update_dtype)
    broad_shape = F.cast(broad_shape, mstype.int64)
    depth_broad = P.Reshape()(depth_range, broad_shape)
    one_hot_bool = P.Equal()(indices_expand, depth_broad)
    one_hot_res = F.cast(one_hot_bool, on_value_dtype)
    return one_hot_res


def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
    """ArgMinWithValue and ArgMaxWithValue grad."""
    expand = P.ExpandDims()
    squeeze = P.Squeeze()
    x_shape = F.shape(x)
    x_dim = len(x_shape)
    x_axis = axis
    onehot_axis_is_neg = False
    if x_axis < 0:
        if not F.is_sequence_shape_unknown(x_shape):
            x_axis = axis + x_dim
        else:
            onehot_axis_is_neg = True
    onehot_axis = x_axis
    if keep_dims:
        dout_expand = dout[1]
        out = op(x)
    else:
        dout_expand = expand(dout[1], onehot_axis)
    out_shape = shape_op(out[0])
    if not F.is_sequence_shape_unknown(out_shape):
        if onehot_axis >= len(out_shape):
            onehot_axis = -1
    type_x = F.dtype(x)
    on_value = F.cast(F.scalar_to_tensor(1.0), type_x)
    off_value = F.cast(F.scalar_to_tensor(0.0), type_x)
    if not F.is_sequence_value_unknown(x_shape):
        depth = 1
        if x_shape:
            depth = x_shape[axis]
        onehot = P.OneHot(onehot_axis)
        dx = dout_expand * onehot(out[0], depth, on_value, off_value)
        if not x_shape:
            dx = squeeze(dx)
        return dx
    x_tensor_shape = P.TensorShape()(x)
    depth = x_tensor_shape[axis]
    if not onehot_axis_is_neg:
        onehot = P.OneHot(onehot_axis)
        dx = dout_expand * onehot(out[0], depth, on_value, off_value)
    else:
        if out[0].value is not None:
            # It is a temporary method: In the pynative mode, out may be a constant tensor. Constant
            # folding occurs in ExpandDims op, but such scenarios are not supported currently.
            out = op(x)
        dx = dout_expand * _onehot_with_neg_axis(onehot_axis, out[0], depth, on_value.dtype)
    return dx


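# Editorial aside: `_argmin_or_argmax_grad` scatters the value-gradient
# `dout[1]` back through a one-hot of the returned index. A minimal NumPy
# sketch (hypothetical function name) of the static-shape path:
def _example_argmax_with_value_grad():
    import numpy as np
    x = np.array([1.0, 4.0, 2.0])
    idx = x.argmax()
    dvalue = 1.0                 # dout[1]: grad w.r.t. the value output
    onehot = np.eye(x.size)[idx]
    dx = dvalue * onehot         # [0.0, 1.0, 0.0]
    return dx

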
@bprop_getters.register(P.BatchMatMul)
def bprop_batchmatmul(self):
    """Grad definition for `BatchMatMul` operation."""
    ta = self.transpose_a
    tb = self.transpose_b
    mul1 = P.BatchMatMul(transpose_a=(ta and tb),
                         transpose_b=(ta or (not tb)))
    mul2 = P.BatchMatMul(transpose_a=((not ta) or tb),
                         transpose_b=(ta and tb))

    def bprop(x, w, out, dout):
        if ta:
            dx = mul1(w, dout)
        else:
            dx = mul1(dout, w)
        if tb:
            dw = mul2(dout, x)
        else:
            dw = mul2(x, dout)
        return binop_grad_common_with_shift(x, w, dx, dw, 2)

    return bprop


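# Editorial aside: for Y = X @ W (no transposes), the bprop above computes
# dX = dOut @ W^T and dW = X^T @ dOut, then reduces any broadcast batch dims
# with shift=2. A minimal NumPy check of those formulas (hypothetical
# function name, shapes chosen arbitrarily):
def _example_batch_matmul_grad():
    import numpy as np
    rng = np.random.default_rng(0)
    x = rng.standard_normal((5, 2, 3))
    w = rng.standard_normal((5, 3, 4))
    dout = np.ones((5, 2, 4))
    dx = dout @ w.transpose(0, 2, 1)   # mul1(dout, w): transpose_b=True
    dw = x.transpose(0, 2, 1) @ dout   # mul2(x, dout): transpose_a=True
    assert dx.shape == x.shape and dw.shape == w.shape
    return dx, dw

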
@bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self):
    """Grad definition for `Add` operation."""

    def bprop(x, y, out, dout):
        return binop_grad_common(x, y, dout, dout)

    return bprop


@bprop_getters.register(P.MatrixInverse)
def get_bprop_matrix_inverse(self):
    """Grad definition for `MatrixInverse` operation."""
    matmul_x1 = nn.MatMul(transpose_x1=True)
    matmul_x2 = nn.MatMul(transpose_x2=True)
    neg = P.Neg()

    def bprop(x, out, dout):
        dx = matmul_x2(dout, out)
        dx = matmul_x1(out, dx)
        dx = neg(dx)
        return (dx,)

    return bprop


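# Editorial aside: the MatrixInverse bprop above implements the identity
# d(A^-1) = -A^-T @ dOut @ A^-T (with out = A^-1). A minimal NumPy sketch
# with a finite-difference spot check; the function name is hypothetical:
def _example_matrix_inverse_grad():
    import numpy as np
    rng = np.random.default_rng(1)
    a = rng.standard_normal((3, 3)) + 3.0 * np.eye(3)  # well conditioned
    inv = np.linalg.inv(a)
    dout = np.ones((3, 3))
    da = -inv.T @ dout @ inv.T     # neg(out^T @ dout @ out^T)
    eps = 1e-6                     # finite-difference check on one entry
    a2 = a.copy()
    a2[0, 1] += eps
    fd = (np.linalg.inv(a2) - inv).sum() / eps
    assert abs(fd - da[0, 1]) < 1e-3
    return da

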
@bprop_getters.register(P.Mul)
def get_bprop_mul(self):
    """Grad definition for `Mul` operation."""
    mul_func = P.Mul()

    def bprop(x, y, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("For 'Mul', gradient not support for complex type currently.")
        bc_dx = mul_func(y, dout)
        bc_dy = mul_func(x, dout)
        return binop_grad_common(x, y, bc_dx, bc_dy)

    return bprop


@bprop_getters.register(P.RealDiv)
def get_bprop_real_div(self):
    """Grad definition for `RealDiv` operation."""
    div_op = P.RealDiv()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("For 'RealDiv', gradient not support for complex type currently.")
        bc_x = div_op(dout, y)
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.Div)
def get_bprop_div(self):
    """Grad definition for `Div` operation."""
    div_op = P.Div()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_op(dout, y)
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


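# Editorial aside: the division bprops above share the quotient rule in the
# form bc_x = dout / y and bc_y = -(dout / y) * out, which equals
# -dout * x / y**2. A minimal NumPy check (hypothetical function name):
def _example_div_grad():
    import numpy as np
    x, y, dout = np.array(6.0), np.array(3.0), np.array(1.0)
    out = x / y
    bc_x = dout / y          # d(x/y)/dx = 1/y
    bc_y = -(bc_x * out)     # d(x/y)/dy = -x/y**2
    assert np.isclose(bc_y, -x / y ** 2)
    return bc_x, bc_y

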
@bprop_getters.register(P.DivNoNan)
def get_bprop_div_no_nan(self):
    """Grad definition for `DivNoNan` operation."""
    div_no_nan_op = P.DivNoNan()
    neg = P.Neg()
    mul_op = P.Mul()

    def bprop(x, y, out, dout):
        bc_x = div_no_nan_op(dout, y)
        bc_y = neg(mul_op(bc_x, out))
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.Xdivy)
def get_bprop_xdivy(self):
    """Grad definition for `Xdivy` operation."""
    div_op = P.Xdivy()

    def bprop(x, y, out, dout):
        x_dtype = F.dtype(x)
        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
        bc_x = div_op(not_zero_x, y) * dout
        bc_y = div_op(-x, F.square(y)) * dout
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.Floor)
def get_bprop_floor(self):
    """Grad definition for `floor` operation."""
    fill_ = P.Fill()
    shape_ = P.Shape()
    dtype_ = P.DType()

    def bprop(x, out, dout):
        if F.is_sequence_value_unknown(shape_(x)):
            bc_x = zeros_like(x)
        else:
            bc_x = fill_(dtype_(x), shape_(x), 0.)
        return (bc_x,)

    return bprop


@bprop_getters.register(P.Ceil)
def get_bprop_ceil(self):
    """Grad definition for `ceil` operation."""
    fill_ = P.Fill()
    shape_ = P.Shape()
    dtype_ = P.DType()

    def bprop(x, out, dout):
        if F.is_sequence_value_unknown(shape_(x)):
            bc_x = zeros_like(x)
        else:
            bc_x = fill_(dtype_(x), shape_(x), 0.)
        return (bc_x,)

    return bprop


@bprop_getters.register(P.FloorDiv)
def get_bprop_floordiv(self):
    """Grad definition for `FloorDiv` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.BitwiseAnd)
def get_bprop_bitwiseand(self):
    """Grad definition for `BitwiseAnd` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.BitwiseOr)
def get_bprop_bitwiseor(self):
    """Grad definition for `BitwiseOr` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.BitwiseXor)
def get_bprop_bitwisexor(self):
    """Grad definition for `BitwiseXor` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.FloorMod)
def get_bprop_floormod(self):
    """Grad definition for `FloorMod` operation."""

    def bprop(x, y, out, dout):
        bc_x = dout
        bc_y = -dout * (x // y)
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.TruncateDiv)
def get_bprop_truncate_div(self):
    """Grad definition for `TruncateDiv` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.TruncateMod)
def get_bprop_truncate_mod(self):
    """Grad definition for `TruncateMod` operation."""
    div_op = P.TruncateDiv()

    def bprop(x, y, out, dout):
        bc_x = dout
        bc_y = -dout * div_op(x, y)
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.Mod)
def get_bprop_mod(self):
    """Grad definition for `Mod` operation."""

    def bprop(x, y, out, dout):
        bc_x = dout
        bc_y = -dout * (x // y)
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.Square)
def get_bprop_square(self):
    """Grad definition for `Square` operation."""
    mul_func = P.Mul()
    fill_func = P.Fill()
    dtype = P.DType()

    def bprop(x, out, dout):
        temp = mul_func(dout, x)
        shape_x = shape_op(x)
        if F.is_sequence_value_unknown(shape_x):
            fill_value = dyn_fill(dtype(temp), dyn_shape_op(x), 2.0)
        else:
            fill_value = fill_func(dtype(temp), shape_x, 2.0)
        dx = mul_func(fill_value, temp)
        return (dx,)

    return bprop


@bprop_getters.register(P.SquaredDifference)
def get_bprop_squared_difference(self):
    """Grad definition for `SquaredDifference` operation."""
    neg = P.Neg()

    def bprop(x, y, out, dout):
        x_grad = 2 * dout * (x - y)
        bc_x = x_grad
        bc_y = neg(x_grad)
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.Xlogy)
def get_bprop_xlogy(self):
    """Grad definition for `Xlogy` operation."""
    log_op = P.Xlogy()
    div_op = P.Xdivy()

    def bprop(x, y, out, dout):
        x_dtype = F.dtype(x)
        not_zero_x = F.cast(F.not_equal(x, F.cast(0.0, x_dtype)), x_dtype)
        bc_x = log_op(not_zero_x, y) * dout
        bc_y = div_op(x, y) * dout
        return binop_grad_common(x, y, bc_x, bc_y)

    return bprop


@bprop_getters.register(P.SquareSumAll)
def get_bprop_square_sum_all(self):
    """Grad definition for `SquareSumAll` operation."""
    mul_func = P.Mul()
    fill_func = P.Fill()
    dtype = P.DType()

    def bprop(x, y, out, dout):
        temp_x = mul_func(dout[0], x)
        temp_y = mul_func(dout[1], y)
        if F.is_sequence_value_unknown(shape_op(x)):
            dx = mul_func(dyn_fill(dtype(temp_x), dyn_shape_op(x), 2.0), temp_x)
        else:
            dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)

        if F.is_sequence_value_unknown(shape_op(y)):
            dy = mul_func(dyn_fill(dtype(temp_y), dyn_shape_op(y), 2.0), temp_y)
        else:
            dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
        return (dx, dy)

    return bprop


@bprop_getters.register(P.Sqrt)
def get_bprop_sqrt(self):
    """Grad definition for `Sqrt` operation."""
    sqrt_grad = G.SqrtGrad()

    def bprop(x, out, dout):
        dx = sqrt_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.SqrtGrad)
def get_bprop_sqrt_grad(self):
    """Grad definition for `SqrtGrad` operation."""

    def bprop(y, grad, out, dout):
        gy = dout / y
        dy = -gy * out
        dgrad = 0.5 * gy
        return dy, dgrad

    return bprop


@bprop_getters.register(P.Rsqrt)
def get_bprop_rsqrt(self):
    """Grad definition for `Rsqrt` operation."""
    rsqrt_grad = G.RsqrtGrad()

    def bprop(x, out, dout):
        dx = rsqrt_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.RsqrtGrad)
def get_bprop_rsqrt_grad(self):
    """Grad definition for `RsqrtGrad` operation."""
    rsqrt_grad = G.RsqrtGrad()

    def bprop(y, grad, out, dout):
        dy = -1.5 * grad * y * y * dout
        dgrad = rsqrt_grad(y, dout)
        return dy, dgrad

    return bprop


@bprop_getters.register(P.Reciprocal)
def get_bprop_reciprocal(self):
    """Grad definition for `Reciprocal` operation."""
    reciprocal_grad = G.ReciprocalGrad()

    def bprop(x, out, dout):
        dx = reciprocal_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Log)
def get_bprop_log(self):
    """Grad definition for `Log` operation."""
    reciprocal = P.Reciprocal()

    def bprop(x, out, dout):
        g = reciprocal(x)
        dx = g * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Log1p)
def get_bprop_log1p(self):
    """Grad definition for `Log1p` operation."""
    reciprocal = P.Reciprocal()

    def bprop(x, out, dout):
        x_1p = x + 1
        g = reciprocal(x_1p)
        dx = g * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Erf)
def get_bprop_erf(self):
    """Grad definition for `Erf` operation."""
    exp = P.Exp()
    square = P.Square()
    sqrt = P.Sqrt()
    cast = P.Cast()
    dtype = P.DType()
    neg = P.Neg()

    def bprop(x, out, dout):
        half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
        x_square = square(x)
        dx = dout * half_root_pi * exp(neg(x_square))
        return (dx,)

    return bprop


@bprop_getters.register(P.Erfc)
def get_bprop_erfc(self):
    """Grad definition for `Erfc` operation."""
    exp = P.Exp()
    square = P.Square()
    sqrt = P.Sqrt()
    cast = P.Cast()
    dtype = P.DType()
    neg = P.Neg()

    def bprop(x, out, dout):
        half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
        x_square = square(x)
        dx = dout * (neg(half_root_pi) * exp(neg(x_square)))
        return (dx,)

    return bprop


@bprop_getters.register(P.Pow)
def get_bprop_pow(self):
    """Grad definition for `Pow` operation."""
    pow_op = P.Pow()
    ln = P.Log()

    def bprop(x, power, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("For 'Pow', gradient not support for complex type currently.")
        bc_dx = power * pow_op(x, power - 1.0) * dout
        shape_x = shape_op(x)
        if F.is_sequence_value_unknown(shape_x):
            x = F.select(x < 0, dyn_fill(F.dtype(x), dyn_shape_op(x), 1), x)
        else:
            x = F.select(x < 0, F.fill(F.dtype(x), F.shape(x), 1), x)
        bc_dpower = out * ln(x) * dout
        return binop_grad_common(x, power, bc_dx, bc_dpower)

    return bprop


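# Editorial aside: for Pow the bprop above uses d(x^p)/dx = p * x^(p-1) and
# d(x^p)/dp = x^p * ln(x), masking negative x with 1 before the log so the
# power branch stays NaN-free. A minimal NumPy sketch (hypothetical name):
def _example_pow_grad():
    import numpy as np
    x, p, dout = np.array(2.0), np.array(3.0), np.array(1.0)
    out = x ** p
    bc_dx = p * x ** (p - 1.0) * dout    # 12.0
    x_safe = np.where(x < 0, 1.0, x)     # mask negatives before log
    bc_dp = out * np.log(x_safe) * dout  # 8 * ln(2)
    assert np.isclose(bc_dx, 12.0)
    return bc_dx, bc_dp

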
@bprop_getters.register(P.Exp)
def get_bprop_exp(self):
    """Grad definition for `Exp` operation."""
    exp_ = P.Exp()

    def bprop(x, out, dout):
        g = exp_(x)
        dx = g * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Einsum)
def get_bprop_einsum(self):
    """Grad definition for `Einsum` operation."""
    grad_func = G.EinsumGrad(self.equation)

    def bprop(x, out, dout):
        dx = grad_func(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Expm1)
def get_bprop_expm1(self):
    """Grad definition for `Expm1` operation."""
    exp_ = P.Exp()

    def bprop(x, out, dout):
        g = exp_(x)
        dx = g * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Minimum)
def get_bprop_minimum(self):
    """Grad definition for `Minimum` operation."""
    input_grad = G.MinimumGrad()

    def bprop(x, y, out, dout):
        dx, dy = input_grad(x, y, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.Maximum)
def get_bprop_maximum(self):
    """Grad definition for `Maximum` operation."""
    input_grad = G.MaximumGrad()

    def bprop(x, y, out, dout):
        dx, dy = input_grad(x, y, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.ReduceSum)
def get_bprop_reducesum(self):
    """Grad definition for `ReduceSum` operation."""

    def bprop(x, axis, out, dout):
        dx = _sum_grad(x, axis, dout)
        return dx, zeros_like(axis)

    return bprop


@bprop_getters.register(P.CumSum)
def get_bprop_cumsum(self):
    """Grad definition for `CumSum` operation."""
    cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)

    def bprop(x, axis, out, dout):
        return cumsum(dout, axis), zeros_like(axis)

    return bprop


@_primexpr
def _split_shape_index(input_shape, axis):
    """Calculate reduce_prod grad transpose indices and perm shape."""
    rank = len(input_shape)
    if isinstance(axis, int):
        axis = tuple([axis])
    reduction_indices = tuple([(i + rank) % rank for i in axis])
    other_indices_list = []
    for i in range(rank):
        if i not in reduction_indices and i not in other_indices_list:
            other_indices_list.append(i)
    other_indices = tuple(other_indices_list)
    reduced_list = [1] + [input_shape[i] for i in reduction_indices]
    other_list = [1] + [input_shape[i] for i in other_indices]
    reduced_num = 1
    for i in reduced_list:
        reduced_num = reduced_num * i
    other_num = 1
    for i in other_list:
        other_num = other_num * i
    perm = reduction_indices + other_indices
    return tuple([reduced_num, other_num]), perm


@_primexpr
def _invert_permutation(perm):
    """Calculate invert permutation."""
    out = [0] * len(perm)
    for i, value in enumerate(perm):
        out[value] = i
    return tuple(out)


def _split_dyn_shape_index(x, axis):
    """Calculate reduce prod grad invert permutation."""
    input_shape = dyn_shape_op(x)
    rank = dyn_rank(x)
    if not isinstance(axis, Tensor):
        axis = Tensor(axis, dtype=mstype.int64)
    reduction_indices = reshape(axis, (-1,))
    reduction_indices = (reduction_indices + rank) % rank
    reduced = P.Cast()(reduction_indices, mstype.int64)

    start = Tensor(0, dtype=mstype.int64)
    delta = Tensor(1, dtype=mstype.int64)
    idx = P.Range()(start, rank, delta)
    other, _ = A.ListDiff()(idx, reduced)
    perm = P.Concat()((reduced, other))
    reduced_num = reduce_prod(P.Cast()(P.Gather()(input_shape, reduced, 0), mstype.int64), ())
    other_num = reduce_prod(P.Cast()(P.Gather()(input_shape, other, 0), mstype.int64), ())
    return (reduced_num, other_num), perm


@bprop_getters.register(P.ReduceProd)
def get_bprop_reduceprod(self):
    """Grad definition for `ReduceProd` operation."""
    transpose = P.Transpose()
    left_cumprod = P.CumProd(exclusive=True)
    right_cumprod = P.CumProd(exclusive=True, reverse=True)

    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'ReduceProd', gradient not support for complex type currently.")
        # Expand dout to full input shape
        input_shape = shape_op(x)
        if input_shape == ():
            dx = _sum_grad(x, axis, dout)
            return dx, zeros_like(axis)

        if F.is_sequence_value_unknown(input_shape):
            input_shape = dyn_shape_op(x)
            input_shape = P.Cast()(input_shape, ms.int64)
            output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
            output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int64)
        else:
            output_shape_kept_dims = reduced_shape(input_shape, axis)

        dout = reshape(dout, output_shape_kept_dims)

        # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
        if F.is_sequence_value_unknown(shape_op(x)):
            pack_shape, perm = _split_dyn_shape_index(x, axis)
        else:
            pack_shape, perm = _split_shape_index(shape_op(x), axis)

        permuted = transpose(x, perm)
        permuted_shape = shape_op(permuted)
        if F.is_sequence_value_unknown(permuted_shape):
            permuted_shape = dyn_shape_op(permuted)
            pack_shape = create_tensor_by_element(pack_shape)
        reshaped = reshape(permuted, pack_shape)

        # Calculate product, leaving out the current entry
        left = left_cumprod(reshaped, 0)
        right = right_cumprod(reshaped, 0)
        y = reshape(left * right, permuted_shape)

        # Invert the transpose and reshape operations.
        # Make sure to set the statically known shape information through a reshape.
        if F.is_sequence_value_unknown(shape_op(permuted)):
            dout = DynamicBroadcastTo()(dout, input_shape)
            out = transpose(y, dyn_invert_permutation(perm)) * dout
        else:
            tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
            grad = tile(dout, tile_scaling)
            out = transpose(y, _invert_permutation(perm)) * grad

        dx = reshape(out, input_shape)
        return dx, zeros_like(axis)

    return bprop


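# Editorial aside: the ReduceProd bprop above relies on the exclusive-cumprod
# trick: the derivative of prod(x) w.r.t. x_i is the product of every other
# entry, i.e. left_exclusive_cumprod * right_exclusive_cumprod. A minimal
# 1-D NumPy sketch (hypothetical function name):
def _example_reduce_prod_grad():
    import numpy as np
    x = np.array([2.0, 3.0, 4.0])
    left = np.concatenate(([1.0], np.cumprod(x[:-1])))              # [1, 2, 6]
    right = np.concatenate((np.cumprod(x[::-1])[::-1][1:], [1.0]))  # [12, 4, 1]
    dx = left * right              # product of the others: [12, 8, 6]
    assert np.allclose(dx, [12.0, 8.0, 6.0])
    return dx

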
@bprop_getters.register(P.CumProd)
def get_bprop_cumprod(self):
    """Grad definition for `CumProd` operation."""
    cumprod = P.CumProd(exclusive=self.exclusive, reverse=self.reverse)
    cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)

    def bprop(x, axis, out, dout):
        """Grad definition for `Product` operation."""
        # This will fail when x contains 0
        prod = cumprod(x, axis)
        out = cumsum(prod * dout, axis)
        return out / x, zeros_like(axis)

    return bprop


@bprop_getters.register(P.ReduceAll)
def get_bprop_reduceall(self):
    """Grad definition for `ReduceAll` operation."""

    def bprop(x, axis, out, dout):
        return zeros_like(x), zeros_like(axis)

    return bprop


@bprop_getters.register(P.ReduceAny)
def get_bprop_reduceany(self):
    """Grad definition for `ReduceAny` operation."""

    def bprop(x, axis, out, dout):
        return zeros_like(x), zeros_like(axis)

    return bprop


@bprop_getters.register(P.ReduceMax)
def get_bprop_reducemax(self):
    """Grad definition for `Max` operation."""

    def bprop(x, axis, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'ReduceMax', gradient not support for complex type currently.")
        dx = _min_or_max_grad(x, axis, out, dout)
        return (dx, zeros_like(axis))

    return bprop


@bprop_getters.register(P.ArgMaxWithValue)
def get_bprop_argmaxwithvalue(self):
    """Grad definition for `ArgMaxWithValue` operation."""
    axis = self.axis
    keep_dims = self.keep_dims
    op = P.ArgMaxWithValue(axis)

    def bprop(x, out, dout):
        dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.ReduceMin)
def get_bprop_reducemin(self):
    """Grad definition for `ReduceMin` operation."""

    def bprop(x, axis, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'ReduceMin', gradient not support for complex type currently.")
        dx = _min_or_max_grad(x, axis, out, dout)
        return (dx, zeros_like(axis))

    return bprop


@bprop_getters.register(P.ArgMinWithValue)
def get_bprop_argminwithvalue(self):
    """Generate bprop for ArgMinWithValue"""
    axis = self.axis
    keep_dims = self.keep_dims
    op = P.ArgMinWithValue(axis)

    def bprop(x, out, dout):
        dx = _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.ReduceMean)
def get_bprop_reduce_mean(self):
    """Grad definition for `ReduceMean` operation."""
    div_op = P.RealDiv()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, axis, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'ReduceMean', gradient not support for complex type currently.")
        grad = _sum_grad(x, axis, dout)
        shape_x = shape_op(x)
        shape_out = shape_op(out)
        if F.is_sequence_value_unknown(shape_x) or F.is_sequence_value_unknown(shape_out):
            shape_x = dyn_shape_op(x)
            shape_out = dyn_shape_op(out)
            div_shape = reduce_prod(cast(shape_x, mstype.float32), ()) /\
                        reduce_prod(cast(shape_out, mstype.float32), ())
            dx = div_op(grad, cast(div_shape, dtype(grad)))
        else:
            div_shape = F.shape_mul(shape_x) / F.shape_mul(shape_out)
            dx = div_op(grad, cast(F.scalar_to_tensor(div_shape), dtype(grad)))
        return dx, zeros_like(axis)

    return bprop


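# Editorial aside: ReduceMean's bprop above is the ReduceSum gradient divided
# by the number of reduced elements, size(x) / size(out). A minimal NumPy
# sketch (hypothetical function name):
def _example_reduce_mean_grad():
    import numpy as np
    x = np.ones((3, 4))
    dout = np.ones((3,))        # grad of mean over axis 1
    grad = np.broadcast_to(dout.reshape(3, 1), x.shape)
    dx = grad / (x.size / dout.size)  # divide by the reduced count (4)
    assert np.allclose(dx, 0.25)
    return dx

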
@bprop_getters.register(P.IsFinite)
def get_bprop_isfinite(self):
    """Grad definition for `IsFinite` operation."""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.IsNan)
def get_bprop_isnan(self):
    """Grad definition for `IsNan` operation."""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.IsInf)
def get_bprop_isinf(self):
    """Grad definition for `IsInf` operation."""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.Equal)
def get_bprop_equal(self):
    """Grad definition for `Equal` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.NotEqual)
def get_bprop_not_equal(self):
    """Grad definition for `NotEqual` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.ApproximateEqual)
def get_bprop_approximate_equal(self):
    """Grad definition for `ApproximateEqual` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.Greater)
def get_bprop_greater(self):
    """Grad definition for `Greater` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.GreaterEqual)
def get_bprop_greater_equal(self):
    """Grad definition for `GreaterEqual` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.Less)
def get_bprop_less(self):
    """Grad definition for `Less` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.LessEqual)
def get_bprop_less_equal(self):
    """Grad definition for `LessEqual` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.LogicalNot)
def get_bprop_logical_not(self):
    """Grad definition for `LogicalNot` operation."""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.LogicalAnd)
def get_bprop_logical_and(self):
    """Grad definition for `LogicalAnd` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.NPUAllocFloatStatus)
def get_bprop_npu_alloc_float_status(self):
    """Grad definition for `NPUAllocFloatStatus` operation."""

    def bprop(out, dout):
        return ()

    return bprop


@bprop_getters.register(P.NPUGetFloatStatus)
def get_bprop_npu_get_float_status(self):
    """Grad definition for `NPUGetFloatStatus` operation."""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.NPUClearFloatStatus)
def get_bprop_npu_clear_float_status(self):
    """Grad definition for `NPUClearFloatStatus` operation."""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.AssignAdd)
def get_bprop_assign_add(self):
    """Grad definition for `AssignAdd` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.AssignSub)
def get_bprop_assign_sub(self):
    """Grad definition for `AssignSub` operation."""

    def bprop(x, y, out, dout):
        return zeros_like(x), zeros_like(y)

    return bprop


@bprop_getters.register(P.Sin)
def get_bprop_sin(self):
    """Grad definition for `Sin` operation."""
    cos = P.Cos()

    def bprop(x, out, dout):
        dx = dout * cos(x)
        return (dx,)

    return bprop


@bprop_getters.register(P.Asin)
def get_bprop_asin(self):
    """Grad definition for `Asin` operation."""
    input_grad = G.AsinGrad()

    def bprop(x, out, dout):
        dx = input_grad(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.AsinGrad)
def get_bprop_asin_grad(self):
    """Grad definition for `AsinGrad` operation."""
    input_grad = G.AsinGrad()
    p_pow = P.Pow()

    def bprop(x, grad, out, dout):
        d2x = dout * grad * x * p_pow((1 - x * x), - 1.5)
        ddy = input_grad(x, dout)
        return (d2x, ddy)

    return bprop


@bprop_getters.register(P.Asinh)
def get_bprop_asinh(self):
    """Grad definition for `Asinh` operation."""
    input_grad = G.AsinhGrad()

    def bprop(x, out, dout):
        dx = input_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.AsinhGrad)
def get_bprop_asinh_grad(self):
    """Grad definition for `AsinhGrad` operation."""
    input_grad = G.AsinhGrad()
    tanh = P.Tanh()

    def bprop(y, grad, out, dout):
        dy = dout * out * -1.0 * tanh(y)
        dgrad = input_grad(y, dout)
        return dy, dgrad

    return bprop


@bprop_getters.register(P.Sinh)
def get_bprop_sinh(self):
    """Grad definition for `Sinh` operation."""
    cosh = P.Cosh()

    def bprop(x, out, dout):
        dx = cosh(x) * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Cos)
def get_bprop_cos(self):
    """Grad definition for `Cos` operation."""
    sin = P.Sin()
    neg = P.Neg()

    def bprop(x, out, dout):
        dx = dout * neg(sin(x))
        return (dx,)

    return bprop


@bprop_getters.register(P.ACos)
def get_bprop_acos(self):
    """Grad definition for `ACos` operation."""
    input_grad = G.ACosGrad()

    def bprop(x, out, dout):
        dx = input_grad(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.ACosGrad)
def get_bprop_acos_grad(self):
    """Grad definition for `ACosGrad` operation."""
    input_grad = G.ACosGrad()
    p_pow = P.Pow()

    def bprop(x, grad, out, dout):
        d2x = -dout * grad * x * p_pow((1 - x * x), - 1.5)
        ddy = input_grad(x, dout)
        return (d2x, ddy)

    return bprop


@bprop_getters.register(P.Acosh)
def get_bprop_acosh(self):
    """Grad definition for `Acosh` operation."""
    input_grad = G.AcoshGrad()

    def bprop(x, out, dout):
        dx = input_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.AcoshGrad)
def get_bprop_acosh_grad(self):
    """Grad definition for `AcoshGrad` operation."""
    input_grad = G.AcoshGrad()
    tanh = P.Tanh()

    def bprop(y, grad, out, dout):
        dy = dout * out * -1.0 / tanh(y)
        dgrad = input_grad(y, dout)
        return dy, dgrad

    return bprop


@bprop_getters.register(P.Cosh)
def get_bprop_cosh(self):
    """Grad definition for `Cosh` operation."""
    sinh = P.Sinh()

    def bprop(x, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("The 'Cosh', gradient not support for complex type currently.")

        dx = sinh(x) * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Abs)
def get_bprop_abs(self):
    """Grad definition for `Abs` operation."""
    abs_grad = G.AbsGrad()

    def bprop(x, out, dout):
        dx = abs_grad(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Conj)
def get_bprop_conj(self):
    """Grad definition for `Conj` operation."""
    conj = P.Conj()

    def bprop(x, out, dout):
        dx = conj(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.AccumulateNV2)
def get_bprop_scalar_accumulatenv2(self):
    """Generate bprop for AccumulateNV2"""

    def bprop(x, out, dout):
        dx = ()
        for _ in range(len(x)):
            dx = dx + (dout,)
        return (dx,)

    return bprop


@bprop_getters.register(P.AddN)
def get_bprop_scalar_addn(self):
    """Generate bprop for AddN"""

    def bprop(x, out, dout):
        if is_sub_class(F.typeof(x), ms.list_):
            dx = []
            for _ in range(len(x)):
                dx.append(dout)
            return (dx,)

        dx = ()
        for _ in range(len(x)):
            dx = dx + (dout,)
        return (dx,)

    return bprop


@bprop_getters.register(P.Sign)
def get_bprop_sign(self):
    """Generate bprop for Sign"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.Round)
def get_bprop_round(self):
    """Generate bprop for Round"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.Atan2)
def get_bprop_atan2(self):
    """Generate bprop for Atan2"""

    square = P.Square()

    def bprop(x, y, out, dout):
        tmp = dout / (square(x) + square(y))
        bc_dx = tmp * y
        bc_dy = tmp * (-x)
        return binop_grad_common(x, y, bc_dx, bc_dy)

    return bprop


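# Editorial aside: the Atan2 bprop above uses the partials
# d/dx atan2(x, y) = y / (x^2 + y^2) and d/dy atan2(x, y) = -x / (x^2 + y^2).
# A minimal NumPy spot check (hypothetical function name):
def _example_atan2_grad():
    import numpy as np
    x, y, dout = np.array(1.0), np.array(2.0), np.array(1.0)
    tmp = dout / (x ** 2 + y ** 2)
    bc_dx, bc_dy = tmp * y, tmp * (-x)  # 0.4 and -0.2
    eps = 1e-6                          # finite-difference sanity check
    fd = (np.arctan2(x + eps, y) - np.arctan2(x, y)) / eps
    assert abs(fd - bc_dx) < 1e-5
    return bc_dx, bc_dy

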
@bprop_getters.register(P.BesselI0e)
def get_bprop_bessel_i0e(self):
    """Generate bprop for BesselI0e"""
    sign = P.Sign()
    bessel_i1e = P.BesselI1e()

    def bprop(x, out, dout):
        dx = dout * (bessel_i1e(x) - sign(x) * out)
        return (dx,)

    return bprop


@bprop_getters.register(P.Atan)
def get_bprop_atan(self):
    """Grad definition for `Atan` operation."""
    input_grad = G.AtanGrad()

    def bprop(x, out, dout):
        dx = input_grad(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.AtanGrad)
def get_bprop_atan_grad(self):
    """Grad definition for `AtanGrad` operation."""
    input_grad = G.AtanGrad()

    def bprop(x, grad, out, dout):
        dgrad = input_grad(x, dout)
        dx = out * dgrad * -2.0 * x
        return dx, dgrad

    return bprop


@bprop_getters.register(P.Tan)
def get_bprop_tan(self):
    """Grad definition for `Tan` operation."""
    reciprocal = P.Reciprocal()
    square = P.Square()
    cos = P.Cos()

    def bprop(x, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("For 'Tan', gradient not support for complex type currently.")

        cosx = cos(x)
        secx2 = square(reciprocal(cosx))
        dx = secx2 * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.BesselI1e)
def get_bprop_bessel_i1e(self):
    """Generate bprop for BesselI1e"""

    sign = P.Sign()
    bessel_i0e = P.BesselI0e()
    less = P.Less()
    select = P.Select()
    reciprocal = P.Reciprocal()
    cast = P.Cast()
    dtype = P.DType()
    abs_ops = P.Abs()

    def bprop(x, out, dout):
        zeros = zeros_like(x)
        np_eps = const_utils.get_np_eps(dtype(x))
        eps = cast(np_eps, dtype(x))
        x_is_valid = less(eps, abs_ops(x))
        x_safe = select(x_is_valid, x, eps + zeros)
        tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
        dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Atanh)
def get_bprop_atanh(self):
    """Grad definition for `Atanh` operation."""
    power = P.Pow()
    div = P.Div()

    def bprop(x, out, dout):
        if x.dtype in (mstype.complex64, mstype.complex128):
            raise TypeError("For 'Atanh', gradient not support for complex type currently.")

        tmp = 1 - power(x, 2)
        dx = div(1, tmp) * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Inv)
def get_bprop_inv(self):
    """Grad definition for 'Inv' operation"""
    inv_grad = G.InvGrad()

    def bprop(x, out, dout):
        dx = inv_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.LinSpace)
def get_bprop_lin_space(self):
    """Grad definition for `LinSpace` operation."""

    def bprop(start, stop, num, out, dout):
        return zeros_like(start), zeros_like(stop), zeros_like(num)

    return bprop


@bprop_getters.register(P.IndexAdd)
def get_bprop_index_add(self):
    """Generate bprop for IndexAdd"""
    gather = P.Gather()
    _axis = self.axis

    def bprop(input_x, indices, input_y, out, dout):
        return dout, zeros_like(indices), gather(dout, indices, _axis)

    return bprop


@bprop_getters.register(P.InplaceUpdate)
def get_bprop_inplace_update(self):
    """Grad definition for `InplaceUpdate` operation."""

    def bprop(x, v, out, dout):
        return zeros_like(x), zeros_like(v)

    return bprop


@bprop_getters.register(P.InplaceUpdateV2)
def get_bprop_inplace_update_v2(self):
    """Grad definition for `InplaceUpdateV2` operation."""

    def bprop(x, indices, v, out, dout):
        return zeros_like(x), zeros_like(indices), zeros_like(v)

    return bprop


@bprop_getters.register(P.InplaceSub)
def get_bprop_inplace_sub(self):
    """Grad definition for `InplaceSub` operation."""

    def bprop(x, input_v, out, dout):
        return zeros_like(x), zeros_like(input_v)

    return bprop


@bprop_getters.register(P.InplaceAdd)
def get_bprop_inplace_add(self):
    """Grad definition for `InplaceAdd` operation."""

    def bprop(x, input_v, out, dout):
        return zeros_like(x), zeros_like(input_v)

    return bprop