mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. See the package's registry page for more details (the original link was lost when this text was extracted).
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -19,38 +19,28 @@ from __future__ import absolute_import
|
|
|
19
19
|
from mindspore import Tensor
|
|
20
20
|
from mindspore.ops.primitive import constexpr
|
|
21
21
|
from mindspore.common import dtype as mstype
|
|
22
|
-
from mindspore.
|
|
23
|
-
from mindspore.ops.
|
|
24
|
-
from mindspore.ops.
|
|
25
|
-
from mindspore.ops._grad.grad_base import convert_to_tensor
|
|
22
|
+
from mindspore.ops._grad_experimental.grad_math_ops import binop_grad_common
|
|
23
|
+
from mindspore.ops._grad_experimental.grad_base import bprop_getters, dyn_ones
|
|
24
|
+
from mindspore.ops._grad_experimental.grad_base import convert_to_tensor, create_tensor_by_element
|
|
26
25
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
27
|
-
from mindspore.ops.operations.array_ops import Tril
|
|
28
26
|
from mindspore.ops.operations.array_ops import MatrixDiagV3
|
|
29
27
|
from mindspore.ops.operations.array_ops import MatrixDiagPartV3
|
|
30
28
|
from mindspore.ops.operations.array_ops import ResizeNearestNeighborV2
|
|
31
29
|
from mindspore.ops.operations.array_ops import MatrixSetDiagV3
|
|
30
|
+
from mindspore.ops.operations.array_ops import MatrixBandPart
|
|
32
31
|
from mindspore.ops.operations.array_ops import Mvlgamma
|
|
33
|
-
from mindspore.ops.operations.array_ops import Triu
|
|
34
|
-
from mindspore.ops.operations.array_ops import IdentityN
|
|
35
32
|
from mindspore.ops.operations.array_ops import IndexFill
|
|
36
33
|
from mindspore.ops.operations.array_ops import IndexPut
|
|
37
|
-
from mindspore.ops.operations.array_ops import CheckNumerics
|
|
38
|
-
from mindspore.ops.operations.array_ops import ConjugateTranspose
|
|
39
|
-
from mindspore.ops.operations.array_ops import SegmentMax
|
|
40
|
-
from mindspore.ops.operations.array_ops import SegmentMin
|
|
41
34
|
from mindspore.ops.operations.array_ops import SegmentSum
|
|
42
|
-
from mindspore.ops.operations.array_ops import TensorScatterElements
|
|
43
35
|
from mindspore.ops.operations.array_ops import ScatterAddWithAxis
|
|
44
36
|
from mindspore.ops.operations.array_ops import Expand
|
|
45
37
|
from mindspore.ops.operations.array_ops import SegmentMean
|
|
46
38
|
from mindspore.ops.operations.array_ops import AffineGrid
|
|
47
39
|
from mindspore.ops.operations.array_ops import Im2Col
|
|
48
40
|
from mindspore.ops.operations.array_ops import Col2Im
|
|
49
|
-
from mindspore.ops.operations.array_ops import StridedSliceV2
|
|
50
41
|
from mindspore.ops.operations.array_ops import MaskedScatter
|
|
51
42
|
from mindspore.ops.operations.array_ops import MaskedSelect
|
|
52
43
|
from mindspore.ops.operations.array_ops import CountNonZero
|
|
53
|
-
from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
|
|
54
44
|
from mindspore.ops.operations.random_ops import LogNormalReverse
|
|
55
45
|
from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal
|
|
56
46
|
from mindspore.ops.operations import _inner_ops as inner
|
|
@@ -58,6 +48,18 @@ from mindspore.ops import functional as F
|
|
|
58
48
|
from mindspore.ops import operations as P
|
|
59
49
|
from mindspore.ops.operations import _grad_ops as G
|
|
60
50
|
from mindspore import context
|
|
51
|
+
from mindspore.ops.primitive import _primexpr
|
|
52
|
+
from mindspore.common.sparse_tensor import RowTensorInner
|
|
53
|
+
from mindspore.ops._utils.utils import generate_shape_index
|
|
54
|
+
|
|
55
|
+
reduce_sum = P.ReduceSum()
|
|
56
|
+
unsorted_segment_sum = P.UnsortedSegmentSum()
|
|
57
|
+
transpose = P.Transpose()
|
|
58
|
+
shape_op = P.Shape()
|
|
59
|
+
reshape = P.Reshape()
|
|
60
|
+
size_op = P.Size()
|
|
61
|
+
invert_permutation = P.InvertPermutation()
|
|
62
|
+
logical_and = P.LogicalAnd()
|
|
61
63
|
|
|
62
64
|
|
|
63
65
|
@constexpr
|
|
@@ -68,91 +70,28 @@ def _raise_value_error(*info):
|
|
|
68
70
|
raise ValueError(info_str)
|
|
69
71
|
|
|
70
72
|
|
|
71
|
-
@bprop_getters.register(P.FillV2)
|
|
72
|
-
def get_bprop_fill_v2(self):
|
|
73
|
-
"""Generate bprop for FillV2"""
|
|
74
|
-
sum_op = P.ReduceSum()
|
|
75
|
-
cast_op = P.Cast()
|
|
76
|
-
shape_op = P.TensorShape()
|
|
77
|
-
|
|
78
|
-
def bprop(shape, value, out, dout):
|
|
79
|
-
dout_type = F.dtype(dout)
|
|
80
|
-
type_list = [
|
|
81
|
-
mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
|
82
|
-
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
|
|
83
|
-
mstype.float16, mstype.float64
|
|
84
|
-
]
|
|
85
|
-
if dout_type in type_list:
|
|
86
|
-
dout = cast_op(dout, mstype.float32)
|
|
87
|
-
dout_shape = shape_op(dout)
|
|
88
|
-
axis = tuple([i for i in range(len(dout_shape))])
|
|
89
|
-
dvalue = sum_op(dout, axis)
|
|
90
|
-
return zeros_like(shape), cast_op(dvalue, dout_type)
|
|
91
|
-
|
|
92
|
-
return bprop
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
@bprop_getters.register(StridedSliceV2)
|
|
96
|
-
def get_bprop_strided_slice_v2(self):
|
|
97
|
-
"""Generate bprop for StridedSliceV2"""
|
|
98
|
-
shape_op = P.Shape()
|
|
99
|
-
dyn_shape_op = P.TensorShape()
|
|
100
|
-
input_grad = StridedSliceV2Grad(self.begin_mask,
|
|
101
|
-
self.end_mask,
|
|
102
|
-
self.ellipsis_mask,
|
|
103
|
-
self.new_axis_mask,
|
|
104
|
-
self.shrink_axis_mask)
|
|
105
|
-
|
|
106
|
-
def bprop(x, begin, end, strides, out, dout):
|
|
107
|
-
x_shape = shape_op(x)
|
|
108
|
-
if F.is_sequence_value_unknown(x_shape):
|
|
109
|
-
x_shape = dyn_shape_op(x)
|
|
110
|
-
dx = input_grad(x_shape, begin, end, strides, dout)
|
|
111
|
-
dx_all = (dx, zeros_like(begin), zeros_like(end), zeros_like(strides))
|
|
112
|
-
return dx_all
|
|
113
|
-
|
|
114
|
-
return bprop
|
|
115
|
-
|
|
116
|
-
|
|
117
73
|
@constexpr
|
|
118
74
|
def _create_tensor(data, dtype):
|
|
119
75
|
return Tensor(data, dtype=dtype)
|
|
120
76
|
|
|
121
77
|
|
|
122
|
-
def _segment_min_or_max_grad(segment_sum_op, input_x, segment_ids, output, dout):
|
|
123
|
-
"""Calculate the gradient of SegmentMax or SegmentMin"""
|
|
124
|
-
gather = P.Gather()
|
|
125
|
-
equal = P.Equal()
|
|
126
|
-
cast = P.Cast()
|
|
127
|
-
divide = P.Div()
|
|
128
|
-
input_x_type = F.dtype(input_x)
|
|
129
|
-
input_x = cast(input_x, mstype.float32)
|
|
130
|
-
output = cast(output, mstype.float32)
|
|
131
|
-
dout = cast(dout, mstype.float32)
|
|
132
|
-
zeros = zeros_like(input_x)
|
|
133
|
-
gathered_outputs = gather(output, segment_ids, 0)
|
|
134
|
-
is_selected = equal(input_x, gathered_outputs)
|
|
135
|
-
num_selected = segment_sum_op(cast(is_selected, F.dtype(dout)), segment_ids)
|
|
136
|
-
weighted_grads = divide(dout, num_selected)
|
|
137
|
-
gathered_grads = gather(weighted_grads, segment_ids, 0)
|
|
138
|
-
return cast(where(is_selected, gathered_grads, zeros), input_x_type), zeros_like(segment_ids)
|
|
139
|
-
|
|
140
|
-
|
|
141
78
|
@bprop_getters.register(P.MaskedFill)
|
|
142
79
|
def get_bprop_masked_select(self):
|
|
143
80
|
"""Generate bprop for MaskedFill"""
|
|
144
81
|
mul_op = P.Mul()
|
|
145
82
|
sum_op = P.ReduceSum()
|
|
146
83
|
is_instance_op = inner.IsInstance()
|
|
84
|
+
rank = P.Rank()
|
|
147
85
|
|
|
148
86
|
def bprop(input_data, mask, value, out, dout):
|
|
149
87
|
mask = F.cast(mask, mstype.float32)
|
|
88
|
+
dout = F.cast(dout, mstype.float32)
|
|
150
89
|
dinput = mul_op(dout, (1 - mask))
|
|
151
90
|
dvalue = mul_op(dout, mask)
|
|
152
91
|
dinput, dvalue = binop_grad_common(input_data, mask, dinput, dvalue)
|
|
153
92
|
# for dynamic rank, reduce axis should be calc
|
|
154
93
|
if F.is_sequence_shape_unknown(P.Shape()(dvalue)):
|
|
155
|
-
axis =
|
|
94
|
+
axis = range(0, rank(dvalue), 1)
|
|
156
95
|
dvalue = sum_op(dvalue, axis)
|
|
157
96
|
else:
|
|
158
97
|
dvalue = sum_op(dvalue)
|
|
@@ -169,37 +108,21 @@ def get_bprop_masked_select(self):
|
|
|
169
108
|
@bprop_getters.register(MaskedScatter)
|
|
170
109
|
def get_bprop_masked_scatter(self):
|
|
171
110
|
"""Generate bprop for MaskedScatter"""
|
|
172
|
-
sort_ = P.Sort(descending=True)
|
|
173
|
-
masked_scatter = MaskedScatter()
|
|
174
111
|
masked_fill = P.MaskedFill()
|
|
175
112
|
masked_select = P.MaskedSelect()
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
reshape = P.Reshape()
|
|
180
|
-
shape = P.Shape()
|
|
181
|
-
|
|
113
|
+
shape = P.TensorShape()
|
|
114
|
+
range_ = P.Range()
|
|
115
|
+
scatter_update = P.TensorScatterElements()
|
|
182
116
|
def bprop(x, mask, updates, out, dout):
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
if diff_num > 0:
|
|
193
|
-
zeros_pad = zeros(diff_num, F.dtype(mask))
|
|
194
|
-
mask_sorted = concat((mask_sorted, zeros_pad))
|
|
195
|
-
zeros_tensor = zeros(size(updates), mstype.float32)
|
|
196
|
-
dupdates = masked_scatter(zeros_tensor, mask_sorted, mask_selected)
|
|
197
|
-
if shape(updates) != ():
|
|
198
|
-
dupdates = reshape(dupdates, shape(updates))
|
|
199
|
-
else:
|
|
200
|
-
zeros_tensor = zeros(shape(updates), mstype.float32)
|
|
201
|
-
dupdates = masked_scatter(zeros_tensor, mask, mask_selected)
|
|
202
|
-
return F.cast(dx, F.dtype(x)), zeros_like(mask), F.cast(dupdates, F.dtype(updates))
|
|
117
|
+
dout = F.cast(dout, mstype.float32)
|
|
118
|
+
dx = masked_fill(dout, mask, F.cast(0, mstype.float32))
|
|
119
|
+
dupdates = F.cast(zeros_like(updates).reshape(-1), mstype.float32)
|
|
120
|
+
dupdates_val = F.cast(masked_select(dout, mask), mstype.float32)
|
|
121
|
+
length = F.cast(shape(dupdates_val)[0], mstype.int32)
|
|
122
|
+
scatter_indices = range_(F.cast(0, mstype.int32), length, F.cast(1, mstype.int32))
|
|
123
|
+
dupdates = scatter_update(dupdates, scatter_indices, dupdates_val)
|
|
124
|
+
dupdates = reshape(dupdates, shape(updates))
|
|
125
|
+
return F.cast(dx, x.dtype), zeros_like(mask), F.cast(dupdates, updates.dtype)
|
|
203
126
|
|
|
204
127
|
return bprop
|
|
205
128
|
|
|
@@ -226,43 +149,19 @@ def get_bprop_mvlgamma(self):
|
|
|
226
149
|
return bprop
|
|
227
150
|
|
|
228
151
|
|
|
229
|
-
@bprop_getters.register(P.TensorScatterDiv)
|
|
230
|
-
def get_bprop_tensor_scatter_div(self):
|
|
231
|
-
"""Generate bprop for TensorScatterDiv"""
|
|
232
|
-
gather_nd = P.GatherNd()
|
|
233
|
-
tensor_scatter_div = P.TensorScatterDiv()
|
|
234
|
-
neg = P.Neg()
|
|
235
|
-
div = P.Div()
|
|
236
|
-
mul = P.Mul()
|
|
237
|
-
|
|
238
|
-
def bprop(x, indices, update, out, dout):
|
|
239
|
-
# (input)' / update
|
|
240
|
-
in_grad = tensor_scatter_div(dout, indices, update)
|
|
241
|
-
|
|
242
|
-
# - (input * (update)') / (update * update)
|
|
243
|
-
gather_update = gather_nd(dout, indices)
|
|
244
|
-
gather_x = gather_nd(x, indices)
|
|
245
|
-
mul_result = mul(update, update)
|
|
246
|
-
neg_result = neg(mul_result)
|
|
247
|
-
update_grad = gather_update * div(gather_x, neg_result)
|
|
248
|
-
|
|
249
|
-
return in_grad, zeros_like(indices), update_grad
|
|
250
|
-
|
|
251
|
-
return bprop
|
|
252
|
-
|
|
253
|
-
|
|
254
152
|
@bprop_getters.register(IndexFill)
|
|
255
153
|
def get_bprop_index_fill(self):
|
|
256
154
|
"""Generate bprop for IndexFill"""
|
|
257
155
|
gather = P.Gather()
|
|
258
156
|
index_fill = IndexFill()
|
|
259
157
|
shape = P.Shape()
|
|
158
|
+
rank = P.Rank()
|
|
260
159
|
|
|
261
160
|
def bprop(x, dim, indices, value, out, dout):
|
|
262
161
|
zero_value = zeros_like(value)
|
|
263
162
|
x_grad = index_fill(dout, dim, indices, zero_value)
|
|
264
163
|
if F.is_sequence_value_unknown(shape(x)):
|
|
265
|
-
if
|
|
164
|
+
if rank(x) == 0:
|
|
266
165
|
value_grad = dout
|
|
267
166
|
else:
|
|
268
167
|
value_grad = gather(dout, indices, dim).sum()
|
|
@@ -286,6 +185,8 @@ def get_bprop_index_put(self):
|
|
|
286
185
|
masked_select = MaskedSelect()
|
|
287
186
|
masked_scatter = MaskedScatter()
|
|
288
187
|
accumulate_grad = self.accumulate
|
|
188
|
+
equal = P.Equal()
|
|
189
|
+
cast = P.Cast()
|
|
289
190
|
index_put = IndexPut(accumulate=accumulate_grad)
|
|
290
191
|
is_ascend = context.get_context("device_target") == 'Ascend'
|
|
291
192
|
|
|
@@ -301,9 +202,10 @@ def get_bprop_index_put(self):
|
|
|
301
202
|
indices_ms = [tile(x, (maxsize,)) if x.shape[0] == 1 else x for x in indices]
|
|
302
203
|
if is_ascend:
|
|
303
204
|
indices_ms = [convert_idx_positive(indices_ms[i], x1.shape[i]) for i in range(len(indices_ms))]
|
|
304
|
-
|
|
205
|
+
indices_me = stack(indices_ms)
|
|
206
|
+
indices_grad = F.transpose(indices_me, F.make_range(F.rank(indices_me)-1, -1, -1))
|
|
305
207
|
values_grad = gather_nd(dout, indices_grad)
|
|
306
|
-
if x2.shape[0]
|
|
208
|
+
if equal(cast(x2.shape[0], mstype.int32), Tensor(1)):
|
|
307
209
|
values_grad = values_grad.sum().reshape(1)
|
|
308
210
|
if values_grad.shape != x2.shape and len(indices) < len(x1.shape):
|
|
309
211
|
_, values_grad = binop_grad_common(x1, x2, dout, values_grad)
|
|
@@ -314,50 +216,6 @@ def get_bprop_index_put(self):
|
|
|
314
216
|
return bprop
|
|
315
217
|
|
|
316
218
|
|
|
317
|
-
@bprop_getters.register(P.TensorScatterSub)
|
|
318
|
-
def get_bprop_tensor_scatter_sub(self):
|
|
319
|
-
"""Generate bprop for TensorScatterSub"""
|
|
320
|
-
gather_nd = P.GatherNd()
|
|
321
|
-
neg = P.Neg()
|
|
322
|
-
|
|
323
|
-
def bprop(x, indices, update, out, dout):
|
|
324
|
-
update_grad = neg(gather_nd(dout, indices))
|
|
325
|
-
return dout, zeros_like(indices), update_grad
|
|
326
|
-
|
|
327
|
-
return bprop
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
@bprop_getters.register(P.TensorScatterMul)
|
|
331
|
-
def get_bprop_tensor_scatter_mul(self):
|
|
332
|
-
"""Generate bprop for TensorScatterMul"""
|
|
333
|
-
gather_nd = P.GatherNd()
|
|
334
|
-
mul_func = P.TensorScatterMul()
|
|
335
|
-
|
|
336
|
-
def bprop(x, indices, update, out, dout):
|
|
337
|
-
gather_update = gather_nd(dout, indices)
|
|
338
|
-
gather_x = gather_nd(x, indices)
|
|
339
|
-
dx = mul_func(dout, indices, update)
|
|
340
|
-
d_update = gather_x * gather_update
|
|
341
|
-
return dx, zeros_like(indices), d_update
|
|
342
|
-
|
|
343
|
-
return bprop
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
@bprop_getters.register(MatrixDiagV3)
|
|
347
|
-
def get_bprop_matrix_diag_v3(self):
|
|
348
|
-
"""Generate bprop for MatrixDiagV3"""
|
|
349
|
-
align = self.align
|
|
350
|
-
matrix_diag_part_v3 = MatrixDiagPartV3(align=align)
|
|
351
|
-
zeros = P.Zeros()
|
|
352
|
-
|
|
353
|
-
def bprop(x, k, num_rows, num_cols, padding_value, out, dout):
|
|
354
|
-
result = (matrix_diag_part_v3(dout, k, zeros((), dout.dtype)), zeros_like(k), zeros_like(num_rows),
|
|
355
|
-
zeros_like(num_cols), zeros_like(padding_value))
|
|
356
|
-
return result
|
|
357
|
-
|
|
358
|
-
return bprop
|
|
359
|
-
|
|
360
|
-
|
|
361
219
|
@bprop_getters.register(MatrixDiagPartV3)
|
|
362
220
|
def get_bprop_matrix_diag_part_v3(self):
|
|
363
221
|
"""Generate bprop for MatrixDiagPartV3"""
|
|
@@ -380,6 +238,17 @@ def get_bprop_matrix_diag_part_v3(self):
|
|
|
380
238
|
return bprop
|
|
381
239
|
|
|
382
240
|
|
|
241
|
+
@bprop_getters.register(MatrixBandPart)
|
|
242
|
+
def get_bprop_matrix_band_part(self):
|
|
243
|
+
"""Grad definition for `MatrixBandPart` operation."""
|
|
244
|
+
matrix_band_part = MatrixBandPart()
|
|
245
|
+
|
|
246
|
+
def bprop(x, lower, upper, out, dout):
|
|
247
|
+
return matrix_band_part(dout, lower, upper), zeros_like(lower), zeros_like(upper)
|
|
248
|
+
|
|
249
|
+
return bprop
|
|
250
|
+
|
|
251
|
+
|
|
383
252
|
@bprop_getters.register(MatrixSetDiagV3)
|
|
384
253
|
def get_bprop_matrix_set_diag_v3(self):
|
|
385
254
|
"""Generate bprop for MatrixSetDiagV3"""
|
|
@@ -409,15 +278,11 @@ def tensor_scatter_possible_replacement(x, indices, updates, out, dout):
|
|
|
409
278
|
scatter_nd = P.ScatterNd()
|
|
410
279
|
equal = P.Equal()
|
|
411
280
|
shape = P.Shape()
|
|
412
|
-
dyn_shape_op = P.TensorShape()
|
|
413
281
|
|
|
414
282
|
x_indicators = F.cast(equal(x, out), mstype.int32)
|
|
415
283
|
possibly_updated = gather_nd(out, indices)
|
|
416
284
|
out_indicators = F.cast(equal(updates, possibly_updated), mstype.int32)
|
|
417
285
|
input_shape = shape(x)
|
|
418
|
-
if F.is_sequence_value_unknown(input_shape):
|
|
419
|
-
input_shape = dyn_shape_op(x)
|
|
420
|
-
|
|
421
286
|
scattered_out_indicators = scatter_nd(indices, out_indicators, input_shape)
|
|
422
287
|
indicators = x_indicators + scattered_out_indicators
|
|
423
288
|
dx = dout * F.cast(x_indicators, F.dtype(dout)) / F.cast(indicators, F.dtype(dout))
|
|
@@ -474,80 +339,16 @@ def get_bprop_coalesce(self):
|
|
|
474
339
|
return bprop
|
|
475
340
|
|
|
476
341
|
|
|
477
|
-
@bprop_getters.register(ConjugateTranspose)
|
|
478
|
-
def get_bprop_conjugate_transpose(self):
|
|
479
|
-
"""Generate bprop for ConjugateTranspose"""
|
|
480
|
-
conjugate_transpose = ConjugateTranspose()
|
|
481
|
-
invert_permutation = P.InvertPermutation()
|
|
482
|
-
|
|
483
|
-
def bprop(x, perm, out, dout):
|
|
484
|
-
return conjugate_transpose(dout, invert_permutation(perm)), zeros_like(perm)
|
|
485
|
-
|
|
486
|
-
return bprop
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
@bprop_getters.register(Triu)
|
|
490
|
-
def get_bprop_triu(self):
|
|
491
|
-
"""Grad definition for 'Triu' operation"""
|
|
492
|
-
diagonal = self.diagonal
|
|
493
|
-
triu = Triu(diagonal)
|
|
494
|
-
|
|
495
|
-
def bprop(x, out, dout):
|
|
496
|
-
dx = triu(dout)
|
|
497
|
-
return (dx,)
|
|
498
|
-
|
|
499
|
-
return bprop
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
@bprop_getters.register(CheckNumerics)
|
|
503
|
-
def get_bprop_check_numerics(self):
|
|
504
|
-
"""Generate bprop for CheckNumerics"""
|
|
505
|
-
check_numerics = CheckNumerics()
|
|
506
|
-
|
|
507
|
-
def bprop(x_input, out, dout):
|
|
508
|
-
return (check_numerics(dout),)
|
|
509
|
-
|
|
510
|
-
return bprop
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
@bprop_getters.register(P.SplitV)
|
|
514
|
-
def get_bprop_split_v(self):
|
|
515
|
-
"""Generate bprop for SplitV"""
|
|
516
|
-
split_dim = self.split_dim
|
|
517
|
-
concat_op = P.Concat(split_dim)
|
|
518
|
-
|
|
519
|
-
def bprop(x_input, output, dout):
|
|
520
|
-
dx = concat_op(dout)
|
|
521
|
-
return (dx,)
|
|
522
|
-
|
|
523
|
-
return bprop
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
@bprop_getters.register(IdentityN)
|
|
527
|
-
def get_bprop_identity_n(self):
|
|
528
|
-
"""Generate bprop for IdentityN"""
|
|
529
|
-
|
|
530
|
-
def bprop(x, out, dout):
|
|
531
|
-
return (dout,)
|
|
532
|
-
|
|
533
|
-
return bprop
|
|
534
|
-
|
|
535
|
-
|
|
536
342
|
@bprop_getters.register(ResizeNearestNeighborV2)
|
|
537
343
|
def get_bprop_resize_nearest_neighbor_v2(self):
|
|
538
344
|
"""Generate bprop for ResizeNearestNeighborV2"""
|
|
539
345
|
align_corners = self.align_corners
|
|
540
346
|
half_pixel_centers = self.half_pixel_centers
|
|
541
|
-
|
|
542
|
-
grad_op = G.ResizeNearestNeighborV2Grad(align_corners, half_pixel_centers, data_format)
|
|
347
|
+
grad_op = G.ResizeNearestNeighborV2Grad(align_corners, half_pixel_centers)
|
|
543
348
|
|
|
544
349
|
def bprop(x, size, output, dout):
|
|
545
350
|
x_shape = P.Shape()(x)
|
|
546
|
-
|
|
547
|
-
x_shape = P.TensorShape()(x)
|
|
548
|
-
grad_in_size = x_shape[1:3]
|
|
549
|
-
if data_format == 'NCHW':
|
|
550
|
-
grad_in_size = x_shape[2:4]
|
|
351
|
+
grad_in_size = x_shape[2:4]
|
|
551
352
|
|
|
552
353
|
if F.is_sequence_value_unknown(P.Shape()(x)):
|
|
553
354
|
dx = grad_op(dout, grad_in_size)
|
|
@@ -559,22 +360,6 @@ def get_bprop_resize_nearest_neighbor_v2(self):
|
|
|
559
360
|
return bprop
|
|
560
361
|
|
|
561
362
|
|
|
562
|
-
@bprop_getters.register(Col2Im)
|
|
563
|
-
def get_bprop_col2im(self):
|
|
564
|
-
"""Generate bprop for Col2Im"""
|
|
565
|
-
ksizes = self.kernel_size
|
|
566
|
-
dilations = self.dilation
|
|
567
|
-
strides = self.stride
|
|
568
|
-
pads = self.padding
|
|
569
|
-
im2col = Im2Col(ksizes=ksizes, dilations=dilations, strides=strides, pads=pads)
|
|
570
|
-
|
|
571
|
-
def bprop(x, output_size, out, dout):
|
|
572
|
-
dx = im2col(dout)
|
|
573
|
-
return dx, zeros_like(output_size)
|
|
574
|
-
|
|
575
|
-
return bprop
|
|
576
|
-
|
|
577
|
-
|
|
578
363
|
@bprop_getters.register(Im2Col)
|
|
579
364
|
def get_bprop_im2col(self):
|
|
580
365
|
"""
|
|
@@ -591,14 +376,13 @@ def get_bprop_im2col(self):
|
|
|
591
376
|
dilation = self.dilations
|
|
592
377
|
stride = self.strides
|
|
593
378
|
padding = (self.pads[0], self.pads[-1])
|
|
594
|
-
shape_op = P.TensorShape()
|
|
595
379
|
col2im = Col2Im(kernel_size=kernel_size,
|
|
596
380
|
dilation=dilation,
|
|
597
381
|
stride=stride,
|
|
598
382
|
padding=padding)
|
|
599
383
|
|
|
600
384
|
def bprop(x, out, dout):
|
|
601
|
-
x_shape =
|
|
385
|
+
x_shape = P.TensorShape()(x)[2:]
|
|
602
386
|
dx = col2im(dout, x_shape)
|
|
603
387
|
return (dx,)
|
|
604
388
|
|
|
@@ -614,18 +398,16 @@ def get_bprop_extract_volume_patches(self):
|
|
|
614
398
|
expend_dims = P.ExpandDims()
|
|
615
399
|
scatter_nd = P.ScatterNd()
|
|
616
400
|
slice_op = P.Slice()
|
|
617
|
-
fill = P.Fill()
|
|
618
401
|
dtype = P.DType()
|
|
619
402
|
cast = P.Cast()
|
|
620
403
|
matmul = P.MatMul()
|
|
621
404
|
_, _, ksize_d, ksize_h, ksize_w = self.kernel_size
|
|
622
405
|
range_ = P.Range()
|
|
623
|
-
dyn_shape_op = P.TensorShape()
|
|
624
406
|
ones_like = P.OnesLike()
|
|
625
407
|
|
|
626
408
|
def _dyn_extract_volume_patches(x, out, dout):
|
|
627
|
-
x_shape =
|
|
628
|
-
out_shape =
|
|
409
|
+
x_shape = shape_op(x)
|
|
410
|
+
out_shape = shape_op(out)
|
|
629
411
|
x_n, x_c, x_d, x_h, x_w = x_shape[0], x_shape[1], x_shape[2], x_shape[3], x_shape[4]
|
|
630
412
|
x_indices_num = 1 + x_d * x_h * x_w
|
|
631
413
|
x_idx = range_(cast(1, mstype.float32), cast(x_indices_num, mstype.float32), cast(1, mstype.float32))
|
|
@@ -683,7 +465,7 @@ def get_bprop_extract_volume_patches(self):
|
|
|
683
465
|
idx_tensor = concat((expend_dims(x_idx_patched, -1), expend_dims(out_idx, -1)))
|
|
684
466
|
idx_map = P.Reshape()(idx_tensor, (-1, 2))
|
|
685
467
|
sp_shape = (x_indices_num, out_indices_num)
|
|
686
|
-
sp_mat_full = scatter_nd(idx_map, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
|
|
468
|
+
sp_mat_full = scatter_nd(idx_map, F.fill(dtype(dout), (out_indices_num,), 1), sp_shape)
|
|
687
469
|
sp_tensor = slice_op(sp_mat_full, (1, 0), (x_indices_num - 1, out_indices_num))
|
|
688
470
|
|
|
689
471
|
grad = P.Transpose()(dout, (0, 2, 3, 4, 1))
|
|
@@ -700,19 +482,6 @@ def get_bprop_extract_volume_patches(self):
|
|
|
700
482
|
return bprop
|
|
701
483
|
|
|
702
484
|
|
|
703
|
-
@bprop_getters.register(Tril)
|
|
704
|
-
def get_bprop_tril(self):
|
|
705
|
-
"""Grad definition for 'Tril' operation"""
|
|
706
|
-
diagonal = self.diagonal
|
|
707
|
-
tril = Tril(diagonal)
|
|
708
|
-
|
|
709
|
-
def bprop(x, out, dout):
|
|
710
|
-
dx = tril(dout)
|
|
711
|
-
return (dx,)
|
|
712
|
-
|
|
713
|
-
return bprop
|
|
714
|
-
|
|
715
|
-
|
|
716
485
|
@bprop_getters.register(SegmentSum)
|
|
717
486
|
def get_bprop_segment_sum(self):
|
|
718
487
|
"""Generate bprop for SegmentSum"""
|
|
@@ -738,16 +507,13 @@ def get_bprop_affinegrid(self):
|
|
|
738
507
|
align_corners = self.align_corners
|
|
739
508
|
input_grad = G.AffineGridGrad(align_corners)
|
|
740
509
|
ones = P.Ones()
|
|
741
|
-
transpose = P.Transpose()
|
|
742
510
|
concat = P.Concat(1)
|
|
743
511
|
concat0 = P.Concat(0)
|
|
744
512
|
tile = P.Tile()
|
|
745
513
|
div = P.Div()
|
|
746
|
-
reshape = P.Reshape()
|
|
747
514
|
linspace = P.LinSpace()
|
|
748
515
|
batmatmul = P.BatchMatMul()
|
|
749
516
|
expend_dims = P.ExpandDims()
|
|
750
|
-
dyn_shape = P.TensorShape()
|
|
751
517
|
reducesum = P.ReduceSum(keep_dims=False)
|
|
752
518
|
|
|
753
519
|
def get_linspace(num):
|
|
@@ -846,7 +612,7 @@ def get_bprop_affinegrid(self):
|
|
|
846
612
|
return transpose(dtheta, perm2), tre
|
|
847
613
|
|
|
848
614
|
def dyn_bprop(theta, output_size, out, dout):
|
|
849
|
-
len_output_size = reducesum(
|
|
615
|
+
len_output_size = reducesum(shape_op(output_size))
|
|
850
616
|
dtheta = dyn_ones(Tensor([1, 3, 2], mstype.int32), mstype.float32)
|
|
851
617
|
ret = dyn_ones(Tensor([1, 6], mstype.int32), mstype.float32)
|
|
852
618
|
if len_output_size == 5:
|
|
@@ -968,44 +734,6 @@ def get_bprop_affinegrid(self):
|
|
|
968
734
|
return bprop
|
|
969
735
|
|
|
970
736
|
|
|
971
|
-
@bprop_getters.register(SegmentMax)
|
|
972
|
-
def get_bprop_segment_max(self):
|
|
973
|
-
"""Generate bprop for SegmentMax"""
|
|
974
|
-
segment_sum = SegmentSum()
|
|
975
|
-
|
|
976
|
-
def bprop(input_x, segment_ids, output, dout):
|
|
977
|
-
return _segment_min_or_max_grad(segment_sum, input_x, segment_ids, output, dout)
|
|
978
|
-
|
|
979
|
-
return bprop
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
@bprop_getters.register(SegmentMin)
|
|
983
|
-
def get_bprop_segment_min(self):
|
|
984
|
-
"""Generate bprop for SegmentMin"""
|
|
985
|
-
segment_sum = SegmentSum()
|
|
986
|
-
|
|
987
|
-
def bprop(input_x, segment_ids, output, dout):
|
|
988
|
-
return _segment_min_or_max_grad(segment_sum, input_x, segment_ids, output, dout)
|
|
989
|
-
|
|
990
|
-
return bprop
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
@bprop_getters.register(TensorScatterElements)
|
|
994
|
-
def get_bprop_tensor_scatter_elements(self):
|
|
995
|
-
"""Generate bprop for TensorScatterElements"""
|
|
996
|
-
gather_d = P.GatherD()
|
|
997
|
-
axis = self.axis
|
|
998
|
-
reduction = self.reduction
|
|
999
|
-
tensor_scatter_elements = TensorScatterElements(axis, reduction)
|
|
1000
|
-
|
|
1001
|
-
def bprop(x, indices, update, out, dout):
|
|
1002
|
-
x_grad = tensor_scatter_elements(dout, indices, zeros_like(update))
|
|
1003
|
-
update_grad = gather_d(dout, axis, indices)
|
|
1004
|
-
return x_grad, zeros_like(indices), update_grad
|
|
1005
|
-
|
|
1006
|
-
return bprop
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
737
|
@bprop_getters.register(ScatterAddWithAxis)
|
|
1010
738
|
def get_bprop_scatter_add_with_axis(self):
|
|
1011
739
|
"""Generate bprop for ScatterAddWithAxis"""
|
|
@@ -1066,38 +794,193 @@ def get_bprop_segment_mean(self):
|
|
|
1066
794
|
"""Generate bprop for SegmentMean"""
|
|
1067
795
|
rank = P.Rank()
|
|
1068
796
|
shape = P.Shape()
|
|
1069
|
-
|
|
1070
|
-
fill = P.Fill()
|
|
797
|
+
fill = P.FillV2()
|
|
1071
798
|
divide = P.Div()
|
|
1072
799
|
segment_sum = SegmentSum()
|
|
1073
800
|
gather = P.Gather()
|
|
1074
801
|
cast = P.Cast()
|
|
1075
|
-
concat = P.Concat()
|
|
1076
|
-
expand_dims = P.ExpandDims()
|
|
1077
802
|
|
|
1078
803
|
def bprop(input_x, segment_ids, output, dout):
|
|
1079
804
|
input_x_type = F.dtype(input_x)
|
|
1080
805
|
input_x = cast(input_x, mstype.float32)
|
|
1081
806
|
dout = cast(dout, mstype.float32)
|
|
1082
807
|
dout_type = F.dtype(dout)
|
|
1083
|
-
|
|
1084
808
|
ones_shape = shape(segment_ids)
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
ones = ()
|
|
1089
|
-
inputx_shape = shape(input_x)
|
|
1090
|
-
if F.is_sequence_value_unknown(inputx_shape):
|
|
1091
|
-
input_rank = dyn_rank(input_x)
|
|
1092
|
-
if input_rank > cast(1, mstype.float32):
|
|
1093
|
-
ones_shape = concat([ones_shape, dyn_ones(expand_dims(input_rank - 1, 0), mstype.int64)])
|
|
1094
|
-
ones = dyn_fill(dout_type, ones_shape, 1)
|
|
1095
|
-
else:
|
|
1096
|
-
input_rank = rank(input_x)
|
|
1097
|
-
ones_shape = ones_shape + (1,) * (input_rank - 1)
|
|
1098
|
-
ones = fill(dout_type, ones_shape, 1)
|
|
1099
|
-
|
|
809
|
+
input_rank = rank(input_x)
|
|
810
|
+
ones_shape = ones_shape + (1,) * (input_rank - 1)
|
|
811
|
+
ones = fill(ones_shape, Tensor(1, dout_type))
|
|
1100
812
|
scaled_grad = divide(dout, segment_sum(ones, segment_ids))
|
|
1101
813
|
return cast(gather(scaled_grad, segment_ids, 0), input_x_type), zeros_like(segment_ids)
|
|
1102
814
|
|
|
1103
815
|
return bprop
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
@bprop_getters.register(P.Ones)
|
|
819
|
+
def get_bprop_ones(self):
|
|
820
|
+
"""Generate bprop for Ones"""
|
|
821
|
+
|
|
822
|
+
def bprop(dims, dtype, out, dout):
|
|
823
|
+
return zeros_like(dims)
|
|
824
|
+
|
|
825
|
+
return bprop
|
|
826
|
+
|
|
827
|
+
|
|
828
|
+
@bprop_getters.register(P.Zeros)
|
|
829
|
+
def get_bprop_zeros(self):
|
|
830
|
+
"""Generate bprop for Zeros"""
|
|
831
|
+
|
|
832
|
+
def bprop(dims, dtype, out, dout):
|
|
833
|
+
return zeros_like(dims)
|
|
834
|
+
|
|
835
|
+
return bprop
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
@bprop_getters.register(P.EmbeddingLookup)
|
|
839
|
+
def get_bprop_embedding_lookup(self):
|
|
840
|
+
"""Generate bprop for EmbeddingLookup"""
|
|
841
|
+
sub_op = P.Sub()
|
|
842
|
+
reshape_op = P.Reshape()
|
|
843
|
+
|
|
844
|
+
def bprop_sparse(x, indices, offset, out, dout):
|
|
845
|
+
x_shp = shape_op(x)
|
|
846
|
+
if F.is_sequence_value_unknown(x_shp):
|
|
847
|
+
raise RuntimeError("Now, EmbeddingLookup op's grad don't support Dynamic Sense!")
|
|
848
|
+
new_indices = sub_op(indices, offset)
|
|
849
|
+
indices_size = size_op(new_indices)
|
|
850
|
+
if indices_size > 0:
|
|
851
|
+
# Reshape the 'new_indices'
|
|
852
|
+
new_indices_shape_changed = (indices_size,)
|
|
853
|
+
new_indices = reshape_op(new_indices, new_indices_shape_changed)
|
|
854
|
+
else:
|
|
855
|
+
new_indices_shape_changed = ()
|
|
856
|
+
x_shp_tail = x_shp[1:]
|
|
857
|
+
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
|
|
858
|
+
# Reshape the 'actual_dout' on device
|
|
859
|
+
actual_dout = reshape_op(dout, actual_dout_shape_changed)
|
|
860
|
+
return RowTensorInner(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
|
|
861
|
+
|
|
862
|
+
return bprop_sparse
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
@_primexpr
|
|
866
|
+
def _generate_inverse_index(x_shape, axis, batch_dims=0):
|
|
867
|
+
x_rank = len(x_shape)
|
|
868
|
+
index = tuple(range(x_rank))
|
|
869
|
+
if axis < 0:
|
|
870
|
+
axis += x_rank
|
|
871
|
+
perm = index[:batch_dims] + index[batch_dims + 1:1 + axis] + (index[batch_dims],) + index[1 + axis:]
|
|
872
|
+
return perm
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
@bprop_getters.register(P.SparseGatherV2)
|
|
876
|
+
def get_bprop_sparse_gather_v2(self):
|
|
877
|
+
"""Generate bprop for SparseGatherV2"""
|
|
878
|
+
|
|
879
|
+
def bprop(x, indices, axis, out, dout):
|
|
880
|
+
x_shp = shape_op(x)
|
|
881
|
+
if axis == 0:
|
|
882
|
+
indices_size = (size_op(indices),)
|
|
883
|
+
if len(x_shp) <= 1:
|
|
884
|
+
x_tail_shp = ()
|
|
885
|
+
else:
|
|
886
|
+
x_tail_shp = x_shp[1:]
|
|
887
|
+
values_shape = indices_size + x_tail_shp
|
|
888
|
+
values = reshape(dout, values_shape)
|
|
889
|
+
indices_new = reshape(indices, indices_size)
|
|
890
|
+
return RowTensorInner(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis)
|
|
891
|
+
if F.rank(dout) == 0:
|
|
892
|
+
dout = P.ExpandDims()(dout, -1)
|
|
893
|
+
if F.rank(indices) == 0:
|
|
894
|
+
indices = P.ExpandDims()(indices, -1)
|
|
895
|
+
out_shp = shape_op(dout)
|
|
896
|
+
ind_shp = shape_op(indices)
|
|
897
|
+
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
|
|
898
|
+
perm_1 = generate_shape_index(out_shp, ind_shp, axis)
|
|
899
|
+
values_transpose = transpose(dout, perm_1)
|
|
900
|
+
params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
|
|
901
|
+
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
|
|
902
|
+
perm_2 = _generate_inverse_index(x_shp, axis)
|
|
903
|
+
params_grad = transpose(params_grad, perm_2)
|
|
904
|
+
return params_grad, zeros_like(indices), zeros_like(axis)
|
|
905
|
+
|
|
906
|
+
return bprop
|
|
907
|
+
|
|
908
|
+
|
|
909
|
+
@bprop_getters.register(P.Unstack)
|
|
910
|
+
def get_bprop_unstack(self):
|
|
911
|
+
"""Generate bprop for Unstack"""
|
|
912
|
+
axis = self.axis
|
|
913
|
+
|
|
914
|
+
def bprop(x, out, dout):
|
|
915
|
+
unstack_grad = P.Stack(axis)
|
|
916
|
+
out = unstack_grad(dout)
|
|
917
|
+
return (out,)
|
|
918
|
+
|
|
919
|
+
return bprop
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
@bprop_getters.register(P.Eye)
|
|
923
|
+
def get_bprop_eye(self):
|
|
924
|
+
"""Generate bprop for Eye"""
|
|
925
|
+
|
|
926
|
+
def bprop(n, m, t, out, dout):
|
|
927
|
+
return zeros_like(n), zeros_like(m), zeros_like(t)
|
|
928
|
+
|
|
929
|
+
return bprop
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
@bprop_getters.register(P.ScatterNdUpdate)
|
|
933
|
+
def get_bprop_scatter_nd_update(self):
|
|
934
|
+
"""Generate bprop for ScatterNdUpdate"""
|
|
935
|
+
op = P.GatherNd()
|
|
936
|
+
|
|
937
|
+
def bprop(x, indices, update, out, dout):
|
|
938
|
+
return dout, zeros_like(indices), op(dout, indices)
|
|
939
|
+
|
|
940
|
+
return bprop
|
|
941
|
+
|
|
942
|
+
|
|
943
|
+
@bprop_getters.register(P.ScatterNonAliasingAdd)
|
|
944
|
+
def get_bprop_scatter_non_aliasing_add_update(self):
|
|
945
|
+
"""Generate bprop for ScatterNonAliasingAdd"""
|
|
946
|
+
op = P.GatherNd()
|
|
947
|
+
|
|
948
|
+
def bprop(x, indices, update, out, dout):
|
|
949
|
+
return dout, zeros_like(indices), op(dout, indices)
|
|
950
|
+
|
|
951
|
+
return bprop
|
|
952
|
+
|
|
953
|
+
|
|
954
|
+
@bprop_getters.register(P.ScatterUpdate)
|
|
955
|
+
def get_bprop_scatter_update(self):
|
|
956
|
+
"""Generate bprop for ScatterUpdate"""
|
|
957
|
+
gather = P.Gather()
|
|
958
|
+
|
|
959
|
+
def bprop(x, indices, update, out, dout):
|
|
960
|
+
return dout, zeros_like(indices), gather(dout, indices, 0)
|
|
961
|
+
|
|
962
|
+
return bprop
|
|
963
|
+
|
|
964
|
+
|
|
965
|
+
@bprop_getters.register(P.TransShape)
|
|
966
|
+
def get_bprop_trans_shape(self):
|
|
967
|
+
"""Generate bprop for TransShape"""
|
|
968
|
+
op = P.TransShape()
|
|
969
|
+
|
|
970
|
+
def bprop(x, shape, out, dout):
|
|
971
|
+
dx = op(dout, shape_op(x))
|
|
972
|
+
return (dx, zeros_like(shape))
|
|
973
|
+
|
|
974
|
+
return bprop
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
@bprop_getters.register(P.Unique)
|
|
978
|
+
def get_bprop_unique(self):
|
|
979
|
+
"""Generate bprop for Unique"""
|
|
980
|
+
op = G.UniqueGrad()
|
|
981
|
+
|
|
982
|
+
def bprop(x, out, dout):
|
|
983
|
+
dx = op(dout, out)
|
|
984
|
+
return (dx,)
|
|
985
|
+
|
|
986
|
+
return bprop
|