mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -1,89 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""Generate bprop for other ops"""
|
|
17
|
-
|
|
18
|
-
from mindspore.ops import operations as P
|
|
19
|
-
from mindspore.ops.operations import _grad_ops as G
|
|
20
|
-
from mindspore.ops.operations import _inner_ops as inner
|
|
21
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
22
|
-
from mindspore.ops._grad.grad_base import bprop_getters
|
|
23
|
-
|
|
24
|
-
# Unused parameters are placeholders.
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
@bprop_getters.register(P.Assign)
|
|
28
|
-
def get_bprop_assign(self):
|
|
29
|
-
"""Generate bprop for Assign"""
|
|
30
|
-
|
|
31
|
-
def bprop(x, y, out, dout):
|
|
32
|
-
return (dout, zeros_like(y))
|
|
33
|
-
return bprop
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@bprop_getters.register(P.InvertPermutation)
|
|
37
|
-
def get_bprop_invert_permutation(self):
|
|
38
|
-
"""Generate bprop for InvertPermutation"""
|
|
39
|
-
|
|
40
|
-
def bprop(x, out, dout):
|
|
41
|
-
return (zeros_like(x),)
|
|
42
|
-
return bprop
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
@bprop_getters.register(P.IOU)
|
|
46
|
-
def get_bprop_iou(self):
|
|
47
|
-
"""Generate bprop for IOU"""
|
|
48
|
-
|
|
49
|
-
def bprop(x, y, out, dout):
|
|
50
|
-
return zeros_like(x), zeros_like(y)
|
|
51
|
-
return bprop
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
@bprop_getters.register(inner.SyncBatchNorm)
|
|
55
|
-
def get_bprop_sync_batch_norm(self):
|
|
56
|
-
"""Grad definition for `SyncBatchNorm` operation."""
|
|
57
|
-
input_grad = G.SyncBatchNormGrad(self.epsilon, self.group, self.device_num)
|
|
58
|
-
|
|
59
|
-
def bprop(x, scale, b, mean, variance, out, dout):
|
|
60
|
-
saved_mean = out[3]
|
|
61
|
-
saved_variance = out[4]
|
|
62
|
-
out = input_grad(dout[0], x, scale, saved_mean, saved_variance)
|
|
63
|
-
dx = out[0]
|
|
64
|
-
dscale = out[1]
|
|
65
|
-
dbias = out[2]
|
|
66
|
-
return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
|
|
67
|
-
return bprop
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
@bprop_getters.register(inner.GpuConvertToDynamicShape)
|
|
71
|
-
def get_bprop_gpu_convert_to_dynamic_shape(self):
|
|
72
|
-
"""Get backprop for GpuConvertToDynamicShape."""
|
|
73
|
-
|
|
74
|
-
def bprop(x, out, dout):
|
|
75
|
-
return (dout,)
|
|
76
|
-
return bprop
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
@bprop_getters.register(P._DynamicLossScale) # pylint: disable=W0212
|
|
80
|
-
def get_bprop_dynamic_loss_scale(self):
|
|
81
|
-
"""Get backprop for dynamic_loss_scale."""
|
|
82
|
-
mul = P.Mul()
|
|
83
|
-
mul.add_prim_attr('split_overflow', True)
|
|
84
|
-
mul.add_prim_attr('layer_overflow', self.layer)
|
|
85
|
-
|
|
86
|
-
def bprop(x, loss_scale, out, dout):
|
|
87
|
-
res = mul(dout, loss_scale)
|
|
88
|
-
return res, zeros_like(loss_scale)
|
|
89
|
-
return bprop
|
|
@@ -1,296 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""grad_sequence_ops"""
|
|
17
|
-
|
|
18
|
-
from mindspore.ops.operations import _sequence_ops as seq
|
|
19
|
-
from mindspore.ops import operations as P
|
|
20
|
-
from mindspore.ops import functional as F
|
|
21
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
22
|
-
from mindspore.ops._grad.grad_base import bprop_getters
|
|
23
|
-
from mindspore.ops.primitive import Primitive
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
tuple_setitem = Primitive("tuple_setitem")
|
|
27
|
-
list_setitem = Primitive("list_setitem")
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@bprop_getters.register(seq.SequenceCount)
|
|
31
|
-
def get_bprop_count(self):
|
|
32
|
-
"""Generate bprop for SequenceCount"""
|
|
33
|
-
|
|
34
|
-
def bprop(x, y, out, dout):
|
|
35
|
-
return (zeros_like(x), zeros_like(y))
|
|
36
|
-
|
|
37
|
-
return bprop
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
@bprop_getters.register(seq.sequence_len)
|
|
41
|
-
def get_bprop_sequence_len(self):
|
|
42
|
-
"""Generate bprop for sequence_len"""
|
|
43
|
-
def bprop(x, out, dout):
|
|
44
|
-
return (zeros_like(x),)
|
|
45
|
-
|
|
46
|
-
return bprop
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
@bprop_getters.register(seq.SequenceAdd)
|
|
50
|
-
def get_bprop_sequence_add(self):
|
|
51
|
-
"""Generate bprop for SequenceAdd"""
|
|
52
|
-
def bprop(x, y, out, dout):
|
|
53
|
-
out_offset = seq.SequenceAddOffset()(x, y)
|
|
54
|
-
dx = seq.SequenceSlice()(dout, out_offset[0], len(x), 1)
|
|
55
|
-
dy = seq.SequenceSlice()(dout, out_offset[1], len(x) + len(y), 1)
|
|
56
|
-
|
|
57
|
-
return (dx, dy)
|
|
58
|
-
|
|
59
|
-
return bprop
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
@bprop_getters.register(seq.SequenceSlice)
|
|
63
|
-
def get_bprop_slice(self):
|
|
64
|
-
"""Generate bprop for SequenceSlice"""
|
|
65
|
-
|
|
66
|
-
def bprop(x, start, stop, step, out, dout):
|
|
67
|
-
dx = seq.SequenceSliceGrad()(dout, x, start, stop, step)
|
|
68
|
-
return (dx, zeros_like(start), zeros_like(stop), zeros_like(step))
|
|
69
|
-
|
|
70
|
-
return bprop
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
@bprop_getters.register(seq.SequenceIndex)
|
|
74
|
-
def get_bprop_index(self):
|
|
75
|
-
"""Generate bprop for SequenceIndex"""
|
|
76
|
-
|
|
77
|
-
def bprop(x, y, start, end, out, dout):
|
|
78
|
-
return (zeros_like(x), zeros_like(y), zeros_like(start), zeros_like(end))
|
|
79
|
-
|
|
80
|
-
return bprop
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
@bprop_getters.register(seq.InSequence)
|
|
84
|
-
def get_bprop_insequence(self):
|
|
85
|
-
"""Generate bprop for InSequence"""
|
|
86
|
-
|
|
87
|
-
def bprop(x, y, out, dout):
|
|
88
|
-
return (zeros_like(x), seq.SequenceZerosLike()(y))
|
|
89
|
-
|
|
90
|
-
return bprop
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
@bprop_getters.register("tuple_equal")
|
|
94
|
-
@bprop_getters.register("list_equal")
|
|
95
|
-
def get_bprop_seq_equal(self):
|
|
96
|
-
"""Generate bprop for tuple_equal and list_equal"""
|
|
97
|
-
|
|
98
|
-
def bprop(x, y, out, dout):
|
|
99
|
-
return (zeros_like(x), zeros_like(y))
|
|
100
|
-
|
|
101
|
-
return bprop
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
@bprop_getters.register("shape_mul")
|
|
105
|
-
def get_bprop_shape_mul(self):
|
|
106
|
-
"""Generate bprop for tuple_equal and list_equal"""
|
|
107
|
-
|
|
108
|
-
def bprop(x, out, dout):
|
|
109
|
-
dx = seq.ShapeMulGrad()(x, dout)
|
|
110
|
-
return (dx,)
|
|
111
|
-
|
|
112
|
-
return bprop
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
@bprop_getters.register("tuple_setitem")
|
|
116
|
-
def get_bprop_tuple_setitem(self):
|
|
117
|
-
"""Generate bprop for TupleSetItem and ListSetItem"""
|
|
118
|
-
|
|
119
|
-
def bprop(x, idx, value, out, dout):
|
|
120
|
-
d_x = tuple_setitem(dout, idx, zeros_like(value))
|
|
121
|
-
d_value = dout[idx]
|
|
122
|
-
d_idx = 0
|
|
123
|
-
return (d_x, zeros_like(d_idx), d_value)
|
|
124
|
-
|
|
125
|
-
return bprop
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
@bprop_getters.register("list_setitem")
|
|
129
|
-
def get_bprop_lsit_setitem(self):
|
|
130
|
-
"""Generate bprop for TupleSetItem and ListSetItem"""
|
|
131
|
-
|
|
132
|
-
def bprop(x, idx, value, out, dout):
|
|
133
|
-
d_x = list_setitem(dout, idx, zeros_like(value))
|
|
134
|
-
d_value = dout[idx]
|
|
135
|
-
d_idx = 0
|
|
136
|
-
return (d_x, zeros_like(d_idx), d_value)
|
|
137
|
-
|
|
138
|
-
return bprop
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
@bprop_getters.register(seq.ListAppend)
|
|
142
|
-
def get_bprop_list_append(self):
|
|
143
|
-
"""Generate bprop for ListAppend"""
|
|
144
|
-
|
|
145
|
-
def bprop(x, value, out, dout):
|
|
146
|
-
d_x = seq.ListAppendAndInsertGrad()(dout, -1)
|
|
147
|
-
return (d_x, zeros_like(value))
|
|
148
|
-
|
|
149
|
-
return bprop
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
@bprop_getters.register(seq.ListInsert)
|
|
153
|
-
def get_bprop_list_insert(self):
|
|
154
|
-
"""Generate bprop for ListInsert"""
|
|
155
|
-
|
|
156
|
-
def bprop(x, idx, value, out, dout):
|
|
157
|
-
d_x = seq.ListAppendAndInsertGrad()(dout, idx)
|
|
158
|
-
return (d_x, zeros_like(idx), zeros_like(value))
|
|
159
|
-
|
|
160
|
-
return bprop
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
@bprop_getters.register(seq.TupleToTensor)
|
|
164
|
-
def get_bprop_tuple_to_tensor(self):
|
|
165
|
-
"""Generate bprop for TupleToTensor"""
|
|
166
|
-
|
|
167
|
-
def bprop(x, dtype, out, dout):
|
|
168
|
-
tuple_type = F.typeof(x)
|
|
169
|
-
dout = P.Cast()(dout, tuple_type)
|
|
170
|
-
d_x = seq.TensorToTuple()(dout)
|
|
171
|
-
return (d_x, zeros_like(dtype))
|
|
172
|
-
|
|
173
|
-
return bprop
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
@bprop_getters.register(seq.ListToTensor)
|
|
177
|
-
def get_bprop_list_to_tensor(self):
|
|
178
|
-
"""Generate bprop for ListToTensor"""
|
|
179
|
-
|
|
180
|
-
def bprop(x, dtype, out, dout):
|
|
181
|
-
tuple_type = F.typeof(x)
|
|
182
|
-
dout = P.Cast()(dout, tuple_type)
|
|
183
|
-
d_x = seq.TensorToList()(dout)
|
|
184
|
-
return (d_x, zeros_like(dtype))
|
|
185
|
-
|
|
186
|
-
return bprop
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
@bprop_getters.register(P.ScalarToTensor)
|
|
190
|
-
def get_bprop_scalar_to_tensor(self):
|
|
191
|
-
"""Generate bprop for ScalarToTensor"""
|
|
192
|
-
|
|
193
|
-
def bprop(x, dtype, out, dout):
|
|
194
|
-
scalar_type = F.typeof(x)
|
|
195
|
-
dout = P.Cast()(dout, scalar_type)
|
|
196
|
-
d_x = seq.TensorToScalar()(dout)
|
|
197
|
-
return (d_x, zeros_like(dtype))
|
|
198
|
-
|
|
199
|
-
return bprop
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
@bprop_getters.register(seq.TensorToTuple)
|
|
203
|
-
def get_bprop_tensor_to_tuple(self):
|
|
204
|
-
"""Generate bprop for TensorToTuple"""
|
|
205
|
-
|
|
206
|
-
def bprop(x, out, dout):
|
|
207
|
-
dtype = F.typeof(x)
|
|
208
|
-
d_x = seq.TupleToTensor()(dout, dtype)
|
|
209
|
-
return (d_x,)
|
|
210
|
-
|
|
211
|
-
return bprop
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
@bprop_getters.register(seq.TensorToList)
|
|
215
|
-
def get_bprop_tensor_to_list(self):
|
|
216
|
-
"""Generate bprop for TensorToList"""
|
|
217
|
-
|
|
218
|
-
def bprop(x, out, dout):
|
|
219
|
-
dtype = F.typeof(x)
|
|
220
|
-
d_x = seq.ListToTensor()(dout, dtype)
|
|
221
|
-
return (d_x,)
|
|
222
|
-
|
|
223
|
-
return bprop
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
@bprop_getters.register(seq.TensorToScalar)
|
|
227
|
-
def get_bprop_tensor_to_scalar(self):
|
|
228
|
-
"""Generate bprop for TensorToScalar"""
|
|
229
|
-
|
|
230
|
-
def bprop(x, out, dout):
|
|
231
|
-
dtype = F.typeof(x)
|
|
232
|
-
d_x = P.ScalarToTensor()(dout, dtype)
|
|
233
|
-
return (d_x,)
|
|
234
|
-
|
|
235
|
-
return bprop
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
@bprop_getters.register("tuple_le")
|
|
239
|
-
@bprop_getters.register("tuple_lt")
|
|
240
|
-
@bprop_getters.register("list_le")
|
|
241
|
-
@bprop_getters.register("list_lt")
|
|
242
|
-
def get_bprop_less(self):
|
|
243
|
-
"""Generate bprop for SequenceLessThan and SequenceLessEqual"""
|
|
244
|
-
|
|
245
|
-
def bprop(x, y, out, dout):
|
|
246
|
-
return (zeros_like(x), zeros_like(y))
|
|
247
|
-
|
|
248
|
-
return bprop
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
@bprop_getters.register(seq.SequenceMul)
|
|
252
|
-
def get_bprop_mul(self):
|
|
253
|
-
"""Generate bprop for SequenceMul"""
|
|
254
|
-
|
|
255
|
-
def bprop(x, y, out, dout):
|
|
256
|
-
dx = x
|
|
257
|
-
if isinstance(x, tuple):
|
|
258
|
-
for i in range(len(x)):
|
|
259
|
-
dx = tuple_setitem(dx, i, dout[i])
|
|
260
|
-
else:
|
|
261
|
-
for i in range(len(x)):
|
|
262
|
-
dx = list_setitem(dx, i, dout[i])
|
|
263
|
-
return (dx, zeros_like(y))
|
|
264
|
-
|
|
265
|
-
return bprop
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
@bprop_getters.register(seq.SequenceMin)
|
|
269
|
-
@bprop_getters.register(seq.SequenceMax)
|
|
270
|
-
def get_bprop_max_min(self):
|
|
271
|
-
"""Generate bprop for SequenceMax and SequenceMax"""
|
|
272
|
-
|
|
273
|
-
def bprop(x, out, dout):
|
|
274
|
-
index = x.index(out)
|
|
275
|
-
if isinstance(x, tuple):
|
|
276
|
-
dx = tuple_setitem(zeros_like(x), index, dout)
|
|
277
|
-
else:
|
|
278
|
-
dx = list_setitem(zeros_like(x), index, dout)
|
|
279
|
-
return (dx,)
|
|
280
|
-
|
|
281
|
-
return bprop
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
@bprop_getters.register("tuple_greater_than")
|
|
285
|
-
@bprop_getters.register("list_greater_than")
|
|
286
|
-
@bprop_getters.register("tuple_greater_equal")
|
|
287
|
-
@bprop_getters.register("list_greater_equal")
|
|
288
|
-
def get_bprop_greater(self):
|
|
289
|
-
"""Generate bprop for tuple_greater_than, list_greater_than,
|
|
290
|
-
tuple_greater_equal, list_greater_equal.
|
|
291
|
-
"""
|
|
292
|
-
|
|
293
|
-
def bprop(x, y, out, dout):
|
|
294
|
-
return (zeros_like(x), zeros_like(y))
|
|
295
|
-
|
|
296
|
-
return bprop
|
|
@@ -1,323 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""bprop primitives"""
|
|
17
|
-
from mindspore.ops._grad.grad_base import bprops, bprop_getters
|
|
18
|
-
from mindspore.ops.composite.multitype_ops._constexpr_utils import infer_out_shape
|
|
19
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
20
|
-
from mindspore.ops.operations._sparse_grad_ops import SparseAddGrad
|
|
21
|
-
from mindspore.common import dtype as mstype
|
|
22
|
-
from mindspore.ops import functional as F
|
|
23
|
-
from mindspore.ops import operations as P
|
|
24
|
-
from mindspore.ops.operations import _csr_ops
|
|
25
|
-
from mindspore.ops.operations.sparse_ops import SparseAdd, CSRSparseMatrixToDense, CSRSparseMatrixToSparseTensor, \
|
|
26
|
-
DenseToCSRSparseMatrix
|
|
27
|
-
from mindspore.ops.operations.sparse_ops import SparseToDenseV2
|
|
28
|
-
|
|
29
|
-
# Unused parameters are placeholders.
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
# COOTensor Bprop Methods
|
|
33
|
-
|
|
34
|
-
@bprops.register("MakeCOOTensor")
def bprop_make_coo_tensor(indices, values, dense_shape, out, dout):
    """Backpropagator for primitive `MakeCOOTensor`.

    `indices` receives a zero gradient and `values` receives the `values`
    field of the incoming COOTensor gradient `dout`. No entry is returned
    for `dense_shape`.
    """
    indices_grad = zeros_like(indices)
    values_grad = dout.values
    return (indices_grad, values_grad,)
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
@bprops.register("COOTensorGetIndices")
def bprop_coo_tensor_get_indices(coo_tensor, out, dout):
    """Backpropagator for primitive `COOTensorGetIndices`.

    Wraps `dout` as the indices part of a COOTensor gradient whose values part
    is all zeros and whose shape matches `coo_tensor`.
    """
    zero_values = zeros_like(coo_tensor.values)
    grad_tensor = F.make_coo_tensor(dout, zero_values, coo_tensor.shape)
    return (grad_tensor,)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
@bprops.register("COOTensorGetValues")
def bprop_coo_tensor_get_values(coo_tensor, out, dout):
    """Backpropagator for primitive `COOTensorGetValues`.

    Wraps `dout` as the values part of a COOTensor gradient whose indices part
    is all zeros and whose shape matches `coo_tensor`.
    """
    zero_indices = zeros_like(coo_tensor.indices)
    grad_tensor = F.make_coo_tensor(zero_indices, dout, coo_tensor.shape)
    return (grad_tensor,)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@bprops.register("COOTensorGetDenseShape")
def bprop_coo_tensor_get_dense_shape(coo_tensor, out, dout):
    """Backpropagator for primitive `COOTensorGetDenseShape`.

    The shape is not differentiable, so a zero-filled counterpart of
    `coo_tensor` is returned.
    """
    zero_grad = zeros_like(coo_tensor)
    return (zero_grad,)
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
@bprop_getters.register(P.SparseToDense)
def get_bprop_sparse_to_dense(self):
    """Generate bprop for SparseToDense.

    Only `values` gets a non-zero gradient: the entries of `dout` at the
    positions listed in `indices`.
    """

    def bprop(indices, values, dense_shape, out, dout):
        indices_grad = zeros_like(indices)
        values_grad = F.gather_nd(dout, indices)
        shape_grad = zeros_like(dense_shape)
        return indices_grad, values_grad, shape_grad

    return bprop
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
@bprop_getters.register(SparseToDenseV2)
def get_bprop_sparse_to_dense_v2(self):
    """Generate bprop for SparseToDenseV2.

    The gradients are:
    - `values`: the entries of `dout` gathered at `indices`.
    - `default_value`: the sum of all of `dout` minus the sum of those gathered
      entries.
    - `indices` and `output_shape`: zero.
    """

    def bprop(indices, output_shape, values, default_value, out, dout):
        gathered = F.gather_nd(dout, indices)
        dout_total = F.reduce_sum(dout)
        default_grad = dout_total - F.reduce_sum(gathered)
        return (zeros_like(indices), zeros_like(output_shape), gathered, default_grad)

    return bprop
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
@bprop_getters.register(P.SparseTensorDenseMatmul)
def get_bprop_sparse_tensor_dense_matmul(self):
    """Generate bprop for SparseTensorDenseMatmul.

    Dense gradient: another SparseTensorDenseMatmul with the sparse adjoint
    flag negated, applied to `dout`. It is transposed when `adjoint_dt` is set.

    Values gradient: for each non-zero, the row of `dout` selected by its row
    (or column, when `adjoint_st`) index is multiplied elementwise with the
    row of the (possibly transposed) dense operand selected by the other
    index, then summed over the last axis. A float16 dense operand is cast to
    float32 for this computation, and the result is cast back to float16.
    """
    transpose_sparse = self.adjoint_st
    transpose_dense = self.adjoint_dt
    flipped_matmul = P.SparseTensorDenseMatmul(not transpose_sparse)
    split_columns = P.Split(-1, 2)
    sum_op = P.ReduceSum()

    def bprop(indices, values, dense_shape, dense, out, dout):
        swap_axes = (1, 0)
        dense_grad = flipped_matmul(indices, values, dense_shape, dout)
        if transpose_dense:
            dense_grad = F.transpose(dense_grad, swap_axes)

        # Compute the values gradient in float32 when the dense input is float16.
        half_input = dense.dtype == mstype.float16
        if half_input:
            dense = P.Cast()(dense, mstype.float32)
            dout = P.Cast()(dout, mstype.float32)

        # Split the (nnz, 2) index matrix into row ids and column ids.
        index_columns = split_columns(indices)
        row_ids = sum_op(index_columns[0], -1)
        col_ids = sum_op(index_columns[1], -1)
        if transpose_sparse:
            dout_rows, dense_rows = col_ids, row_ids
        else:
            dout_rows, dense_rows = row_ids, col_ids

        lhs = F.gather(dout, dout_rows, 0)
        dense_operand = F.transpose(dense, swap_axes) if transpose_dense else dense
        rhs = F.gather(dense_operand, dense_rows, 0)
        values_grad = F.reduce_sum(lhs * rhs, -1)
        if half_input:
            values_grad = P.Cast()(values_grad, mstype.float16)
        return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad

    return bprop
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
@bprop_getters.register(SparseAdd)
def get_bprop_sparse_add(self):
    """Generate bprop for SparseAdd.

    `SparseAddGrad` distributes the output-values gradient `dout[1]` back onto
    the two operands. Each piece is reshaped to its operand's values shape; a
    runtime `TensorShape` is used when the static shape is not known. Indices,
    dense shapes and `thresh` receive zero gradients.
    """
    add_grad = SparseAddGrad()
    static_shape = P.Shape()
    runtime_shape = P.TensorShape()
    reshape_op = P.Reshape()

    def bprop(x1_indices, x1_values, x1_shape, x2_indices, x2_values, x2_shape, thresh, out, dout):
        grad_x1, grad_x2 = add_grad(dout[1], x1_indices, x2_indices, out[0])

        x1_target = static_shape(x1_values)
        if F.is_sequence_value_unknown(x1_target):
            x1_target = runtime_shape(x1_values)
        x1_values_grad = reshape_op(grad_x1, x1_target)

        x2_target = static_shape(x2_values)
        if F.is_sequence_value_unknown(x2_target):
            x2_target = runtime_shape(x2_values)
        x2_values_grad = reshape_op(grad_x2, x2_target)

        return (zeros_like(x1_indices), x1_values_grad, zeros_like(x1_shape),
                zeros_like(x2_indices), x2_values_grad, zeros_like(x2_shape),
                zeros_like(thresh),)

    return bprop
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
# CSRTensor Bprop Methods
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
@bprops.register("MakeCSRTensor")
def bprop_make_csr_tensor(indptr, indices, values, dense_shape, out, dout):
    """Backpropagator for primitive `MakeCSRTensor`.

    `indptr` and `indices` receive zero gradients. `values` and `dense_shape`
    receive the matching fields of the CSRTensor gradient `dout`.
    """
    indptr_grad = zeros_like(indptr)
    indices_grad = zeros_like(indices)
    return (indptr_grad, indices_grad, dout.values, dout.shape)
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
@bprops.register("CSRTensorGetIndptr")
def bprop_csr_tensor_get_indptr(csr_tensor, out, dout):
    """Backpropagator for primitive `CSRTensorGetIndptr`.

    Places `dout` in the indptr slot of a CSRTensor gradient. Its indices and
    values are zeros, and its shape matches `csr_tensor`.
    """
    zero_indices = zeros_like(csr_tensor.indices)
    zero_values = zeros_like(csr_tensor.values)
    grad_tensor = F.make_csr_tensor(dout, zero_indices, zero_values, csr_tensor.shape)
    return (grad_tensor,)
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
@bprops.register("CSRTensorGetIndices")
def bprop_csr_tensor_get_indices(csr_tensor, out, dout):
    """Backpropagator for primitive `CSRTensorGetIndices`.

    Places `dout` in the indices slot of a CSRTensor gradient. Its indptr and
    values are zeros, and its shape matches `csr_tensor`.
    """
    zero_indptr = zeros_like(csr_tensor.indptr)
    zero_values = zeros_like(csr_tensor.values)
    grad_tensor = F.make_csr_tensor(zero_indptr, dout, zero_values, csr_tensor.shape)
    return (grad_tensor,)
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
@bprops.register("CSRTensorGetValues")
def bprop_csr_tensor_get_values(csr_tensor, out, dout):
    """Backpropagator for primitive `CSRTensorGetValues`.

    Places `dout` in the values slot of a CSRTensor gradient. Its indptr and
    indices are zeros, and its shape matches `csr_tensor`.
    """
    zero_indptr = zeros_like(csr_tensor.indptr)
    zero_indices = zeros_like(csr_tensor.indices)
    grad_tensor = F.make_csr_tensor(zero_indptr, zero_indices, dout, csr_tensor.shape)
    return (grad_tensor,)
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
@bprops.register("CSRTensorGetDenseShape")
def bprop_csr_tensor_get_dense_shape(csr_tensor, out, dout):
    """Backpropagator for primitive `CSRTensorGetDenseShape`.

    The shape is not differentiable, so a zero-filled counterpart of
    `csr_tensor` is returned.
    """
    zero_grad = zeros_like(csr_tensor)
    return (zero_grad,)
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
@bprop_getters.register(_csr_ops.CSRReduceSum)
def get_bprop_csr_reduce_sum(self):
    """Back-propagation for CSRReduceSum.

    Steps:
    1. Reshape `dout` to the reduced shape with kept dims.
    2. Tile it back up to the full dense `shape`.
    3. Sample it at the input's non-zero positions with `csr_gather`.
    """
    def bprop(indptr, indices, values, shape, axis, out, dout):
        kept_dims_shape = F.reduced_shape(shape, axis)
        multiples = F.tuple_div(shape, kept_dims_shape)
        dout_kept = F.reshape(dout, kept_dims_shape)
        dense_grad = F.tile(dout_kept, multiples)
        values_grad = F.csr_gather(indptr, indices, dense_grad, shape)
        return (indptr, indices, values_grad, (), zeros_like(axis))

    return bprop
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
@bprop_getters.register(_csr_ops.CSRMV)
def get_bprop_csr_mv(self):
    """Back-propagation for CSRMV.

    Dense gradient: the transposed CSR matrix multiplied with `dout`. The
    transpose is built as follows:
    1. Sort the column indices (cast to float32 for `F.sort`); they become
       the new row indices.
    2. Permute the expanded row indices and the values by the same order.
    3. Re-compress the result with `coo2csr`.

    Values gradient: `dout[row] * dense[col]` for each non-zero, summed over
    axis 1.
    """
    def bprop(indptr, indices, values, dense_shape, dense, out, dout):
        nnz = indices.shape[0]
        row_ids = F.csr2coo(indptr, nnz)
        index_dtype = row_ids.dtype
        sorted_cols, order = F.sort(indices.astype(mstype.float32))
        t_rows = sorted_cols.astype(index_dtype)
        t_cols = row_ids[order]
        t_values = values[order]
        t_indptr = F.coo2csr(t_rows, dense_shape[1])
        t_shape = (dense_shape[1], dense_shape[0])
        csr_t = F.make_csr_tensor(t_indptr, t_cols, t_values, t_shape)

        dense_grad = F.csr_mv(csr_t, dout)
        dout_rows = F.gather(dout, row_ids, 0)
        dense_rows = F.gather(dense, indices, 0)
        values_grad = F.reduce_sum(dout_rows * dense_rows, 1)
        return (indptr, indices, values_grad, (), dense_grad)

    return bprop
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
@bprop_getters.register(_csr_ops.CSRMul)
def get_bprop_csr_mul(self):
    """
    Back-propagation for CSRMul.

    Note: Broadcast of first dimension of the dense input is not supported for `CSRMul`,
    because this would require sparse reduce sum on the first axis, which is not logically contiguous
    for the CSR storage format. If broadcast of first dimension should be desired, the operator `*`
    could be used instead, which bypasses the constraint by making use of the indices in the CSR input
    to index the dense input.
    """
    def bprop(indptr, indices, values, shape, dense, out, dout):
        # Reject first-dimension broadcast before doing any work. `< 2` also
        # covers a 0-d dense input, which would otherwise fail on `dense.shape[0]`.
        if len(dense.shape) < 2 or dense.shape[0] == 1:
            raise ValueError(
                "Backpropagation for CSRMul with broadcast for the first dimension is not supported! Use `*` instead")
        # Gradient w.r.t. the CSR values: dout (in the CSR layout) multiplied by dense.
        csr_tensor_grad_value = F.csr_mul(F.make_csr_tensor(indptr, indices, dout, shape), dense).values
        # Gradient w.r.t. dense at each CSR non-zero: dout * CSR values.
        dense_grad_value = F.mul(dout, values)
        if dense.shape[1] == 1:
            # dense has width 1 along axis 1: sum the per-non-zero grads along axis 1.
            dense_grad = F.csr_reduce_sum(F.make_csr_tensor(indptr, indices, dense_grad_value, shape), 1)
        else:
            # Scatter the per-non-zero grads into a zero tensor shaped like `dense`.
            row = F.csr2coo(indptr, indices.shape[0])
            coo_idx = P.Stack(-1)((row, indices))
            dense_grad = F.tensor_scatter_update(zeros_like(dense), coo_idx, dense_grad_value)
        res = (indptr, indices, csr_tensor_grad_value, (), dense_grad)
        return res
    return bprop
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
@bprop_getters.register(_csr_ops.CSRDiv)
def get_bprop_csr_div(self):
    """
    Back-propagation for CSRDiv.

    Note: Broadcast of first dimension of the dense input is not supported for `CSRDiv`,
    because this would require sparse reduce sum on the first axis, which is not logically contiguous
    for the CSR storage format. If broadcast of first dimension should be desired, the operator `/`
    could be used instead, which bypasses the constraint by making use of the indices in the CSR input
    to index the dense input.
    """
    def bprop(indptr, indices, values, shape, dense, out, dout):
        # Reject first-dimension broadcast before computing broadcast-reduction
        # axes. Otherwise invalid inputs may fail earlier with a less helpful error.
        # `< 2` also covers a 0-d dense input, which would fail on `dense.shape[0]`.
        if len(dense.shape) < 2 or dense.shape[0] == 1:
            raise ValueError(
                "Backpropagation for CSRDiv with broadcast for the first dimension is not supported! Use `/` instead")
        batch_dim_csr_start = 2
        batch_dim_dense_start = len(dense.shape) - (len(shape) - batch_dim_csr_start)
        if batch_dim_dense_start < 0:
            batch_dim_dense_start = 0
        feature_dim = infer_out_shape(shape[:batch_dim_csr_start], dense.shape[:batch_dim_dense_start])

        shape_x = feature_dim + shape[batch_dim_csr_start:]
        # The dense operand's trailing dims come from `dense.shape`, mirroring
        # how `shape_x` uses the CSR operand's own `shape`.
        shape_y = feature_dim + dense.shape[batch_dim_dense_start:]
        reduce_x, reduce_y = F.broadcast_gradient_args(shape_x, shape_y)

        # Gradient w.r.t. the CSR values: dout (in the CSR layout) divided by dense.
        csr_tensor_grad = F.csr_div(F.make_csr_tensor(indptr, indices, dout, shape), dense)
        if reduce_x:
            csr_tensor_grad_value = P.ReduceSum(True)(csr_tensor_grad.values, reduce_x)
        else:
            csr_tensor_grad_value = csr_tensor_grad.values
        # Gradient w.r.t. dense at each CSR non-zero: -out * (dout / dense).
        dense_grad_value = F.neg_tensor(F.mul(out, csr_tensor_grad_value))
        if dense.shape[1] == 1:
            # dense has width 1 along axis 1: sum the per-non-zero grads along axis 1.
            dense_grad = F.csr_reduce_sum(F.make_csr_tensor(indptr, indices, dense_grad_value, shape), 1)
        else:
            # Scatter the per-non-zero grads into a zero tensor shaped like `dense`.
            row = F.csr2coo(indptr, indices.shape[0])
            coo_idx = P.Stack(-1)((row, indices))
            dense_grad = F.tensor_scatter_update(zeros_like(dense), coo_idx, dense_grad_value)
        if reduce_y:
            # NOTE(review): this reduces the CSR-values gradient rather than
            # `dense_grad`; looks suspicious — confirm intended behavior.
            dense_grad = P.ReduceSum(True)(csr_tensor_grad_value, reduce_y)
        res = (indptr, indices, csr_tensor_grad_value, (), dense_grad)
        return res
    return bprop
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
@bprop_getters.register(_csr_ops.CSR2COO)
def get_bprop_csr2coo(self):
    """Back-propagation for CSR2COO: both inputs receive zero gradients."""
    def bprop(indptr, nnz, out, dout):
        indptr_grad = zeros_like(indptr)
        nnz_grad = zeros_like(nnz)
        return indptr_grad, nnz_grad

    return bprop
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
@bprop_getters.register(_csr_ops.COO2CSR)
def get_bprop_coo2csr(self):
    """Back-propagation for COO2CSR: both inputs receive zero gradients."""
    def bprop(row_indices, height, out, dout):
        row_indices_grad = zeros_like(row_indices)
        height_grad = zeros_like(height)
        return row_indices_grad, height_grad

    return bprop
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
# Module-level primitive instances used by `bprop_csr_sparse_matrix_to_dense`.
csr_to_coo = CSRSparseMatrixToSparseTensor()
dense_to_csr = DenseToCSRSparseMatrix()


@bprops.register(CSRSparseMatrixToDense)
def bprop_csr_sparse_matrix_to_dense(shape, batch, indptr, indices, values, out, dout):
    """Backpropagator for primitive `CSRSparseMatrixToDense`.

    Recovers the COO index of the stored entries from the CSR components (the
    first output of `CSRSparseMatrixToSparseTensor`). It then converts the
    dense gradient `dout` back to CSR form at those positions with
    `DenseToCSRSparseMatrix`.
    """
    coo_index = csr_to_coo(shape, batch, indptr, indices, values)[0]
    return dense_to_csr(dout, coo_index)
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
# Module-level primitive instance used by `bprop_dense_to_csr_sparse_matrix`.
csr_to_dense = CSRSparseMatrixToDense()


@bprops.register(DenseToCSRSparseMatrix)
def bprop_dense_to_csr_sparse_matrix(dense_input, indices, out, dout):
    """Backpropagator for primitive `DenseToCSRSparseMatrix`.

    Densifies the values gradient `dout[4]` using the first four forward
    outputs `out[:4]` as the CSR structure. `indices` gets a zero gradient.
    """
    dense_shape, batch_pointers, row_pointers, col_indices = out[:4]
    values_grad = dout[4]
    dense_grad = csr_to_dense(dense_shape, batch_pointers, row_pointers, col_indices, values_grad)
    return dense_grad, zeros_like(indices)
|