mindspore 1.10.0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -15,23 +15,24 @@
|
|
|
15
15
|
|
|
16
16
|
"""Define the grad rules of math related operations."""
|
|
17
17
|
|
|
18
|
-
from functools import reduce
|
|
19
18
|
import numpy as np
|
|
20
19
|
import mindspore as ms
|
|
21
20
|
from mindspore import nn
|
|
22
|
-
from
|
|
23
|
-
from
|
|
24
|
-
from
|
|
25
|
-
from
|
|
26
|
-
from
|
|
27
|
-
from .
|
|
28
|
-
from .
|
|
29
|
-
from
|
|
30
|
-
from
|
|
31
|
-
from
|
|
32
|
-
from
|
|
33
|
-
from
|
|
34
|
-
from
|
|
21
|
+
from mindspore.common import Tensor
|
|
22
|
+
from mindspore.common import dtype as mstype
|
|
23
|
+
from mindspore.ops import functional as F
|
|
24
|
+
from mindspore.ops import operations as P
|
|
25
|
+
from mindspore.ops.operations import _grad_ops as G
|
|
26
|
+
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
27
|
+
from mindspore.ops.functional import broadcast_gradient_args, reduced_shape, tuple_div
|
|
28
|
+
from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_invert_permutation
|
|
29
|
+
from mindspore.ops._grad.grad_base import convert_to_tensor
|
|
30
|
+
from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
|
|
31
|
+
from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
|
|
32
|
+
from mindspore.ops.primitive import _primexpr
|
|
33
|
+
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
34
|
+
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass, DynamicBroadcastTo
|
|
35
|
+
from mindspore.ops.operations import array_ops as A
|
|
35
36
|
|
|
36
37
|
shape_op = P.Shape()
|
|
37
38
|
dyn_shape_op = P.TensorShape()
|
|
@@ -39,7 +40,7 @@ reduce_prod = P.ReduceProd()
|
|
|
39
40
|
reduce_sum = P.ReduceSum()
|
|
40
41
|
reshape = P.Reshape()
|
|
41
42
|
tile = P.Tile()
|
|
42
|
-
is_sub_class =
|
|
43
|
+
is_sub_class = IsSubClass()
|
|
43
44
|
to_array = P.TupleToArray()
|
|
44
45
|
real_div = P.RealDiv()
|
|
45
46
|
|
|
@@ -56,17 +57,17 @@ def dyn_binop_grad_common(x, y, dx, dy):
|
|
|
56
57
|
dx_origin_dtype = dx.dtype
|
|
57
58
|
if dx_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
|
|
58
59
|
dx = F.cast(dx, mstype.float32)
|
|
59
|
-
dx =
|
|
60
|
+
dx = sum_grad_reduce_axis(dx, rx)
|
|
60
61
|
dx = F.cast(dx, dx_origin_dtype)
|
|
61
62
|
else:
|
|
62
|
-
dx =
|
|
63
|
+
dx = sum_grad_reduce_axis(dx, rx)
|
|
63
64
|
dy_origin_dtype = dy.dtype
|
|
64
65
|
if dy_origin_dtype in (mstype.int16, mstype.int32, mstype.int64):
|
|
65
66
|
dy = F.cast(dy, mstype.float32)
|
|
66
|
-
dy =
|
|
67
|
+
dy = sum_grad_reduce_axis(dy, ry)
|
|
67
68
|
dy = F.cast(dy, dy_origin_dtype)
|
|
68
69
|
else:
|
|
69
|
-
dy =
|
|
70
|
+
dy = sum_grad_reduce_axis(dy, ry)
|
|
70
71
|
reduce_dx = reshape(dx, shape_of_x)
|
|
71
72
|
reduce_dy = reshape(dy, shape_of_y)
|
|
72
73
|
return reduce_dx, reduce_dy
|
|
@@ -83,8 +84,8 @@ def dyn_binop_grad_common_with_shift(x, y, dx, dy, shift):
|
|
|
83
84
|
broadcast_shape_of_x = shape_of_x[:-shift]
|
|
84
85
|
broadcast_shape_of_y = shape_of_y[:-shift]
|
|
85
86
|
rx, ry = DynamicBroadcastGradientArgs()(broadcast_shape_of_x, broadcast_shape_of_y)
|
|
86
|
-
dx =
|
|
87
|
-
dy =
|
|
87
|
+
dx = sum_grad_reduce_axis(dx, rx)
|
|
88
|
+
dy = sum_grad_reduce_axis(dy, ry)
|
|
88
89
|
reduce_dx = reshape(dx, shape_of_x)
|
|
89
90
|
reduce_dy = reshape(dy, shape_of_y)
|
|
90
91
|
return reduce_dx, reduce_dy
|
|
@@ -111,7 +112,7 @@ def binop_grad_common(x, y, dx, dy):
|
|
|
111
112
|
# if input shape is the same as dout shape, do not need to reduce
|
|
112
113
|
reduce_dx = dx
|
|
113
114
|
reduce_dy = dy
|
|
114
|
-
if not (
|
|
115
|
+
if not (F.is_sequence_value_unknown(shape_of_x) or F.is_sequence_value_unknown(shape_of_y)):
|
|
115
116
|
rx = broadcast_gradient_args(shape_of_x, shape_of_y)
|
|
116
117
|
if rx[0]:
|
|
117
118
|
# if dx is scalar whose shape is (), do not need reduce
|
|
@@ -124,11 +125,12 @@ def binop_grad_common(x, y, dx, dy):
|
|
|
124
125
|
dy = _reduce_sum_with_cast(dy, rx[1])
|
|
125
126
|
reduce_dy = reshape(dy, shape_of_y)
|
|
126
127
|
return reduce_dx, reduce_dy
|
|
127
|
-
|
|
128
|
+
|
|
129
|
+
if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
|
|
128
130
|
# x or y is scalar
|
|
129
|
-
if not shape_of_x:
|
|
131
|
+
if not isinstance(shape_of_x, tuple):
|
|
130
132
|
reduce_dx = _reduce_sum_with_cast(dx, ())
|
|
131
|
-
if not shape_of_y:
|
|
133
|
+
if not isinstance(shape_of_y, tuple):
|
|
132
134
|
reduce_dy = _reduce_sum_with_cast(dy, ())
|
|
133
135
|
return reduce_dx, reduce_dy
|
|
134
136
|
|
|
@@ -148,7 +150,7 @@ def binop_grad_common_with_shift(x, y, dx, dy, shift):
|
|
|
148
150
|
# if input shape is the same as dout shape, do not need to reduce
|
|
149
151
|
reduce_dx = dx
|
|
150
152
|
reduce_dy = dy
|
|
151
|
-
if not (
|
|
153
|
+
if not (F.is_sequence_value_unknown(broadcast_shape_of_x) or F.is_sequence_value_unknown(broadcast_shape_of_y)):
|
|
152
154
|
rx = broadcast_gradient_args(broadcast_shape_of_x, broadcast_shape_of_y)
|
|
153
155
|
if rx[0]:
|
|
154
156
|
# if dx is scalar whose shape is (), do not need reduce
|
|
@@ -161,49 +163,56 @@ def binop_grad_common_with_shift(x, y, dx, dy, shift):
|
|
|
161
163
|
dy = _reduce_sum_with_cast(dy, rx[1])
|
|
162
164
|
reduce_dy = reshape(dy, shape_of_y)
|
|
163
165
|
return reduce_dx, reduce_dy
|
|
164
|
-
|
|
166
|
+
|
|
167
|
+
if not isinstance(shape_of_x, tuple) or not isinstance(shape_of_y, tuple):
|
|
165
168
|
# x or y is scalar
|
|
166
|
-
if not shape_of_x:
|
|
169
|
+
if not isinstance(shape_of_x, tuple):
|
|
167
170
|
reduce_dx = _reduce_sum_with_cast(dx, ())
|
|
168
|
-
if not shape_of_y:
|
|
171
|
+
if not isinstance(shape_of_y, tuple):
|
|
169
172
|
reduce_dy = _reduce_sum_with_cast(dy, ())
|
|
170
173
|
return reduce_dx, reduce_dy
|
|
171
174
|
|
|
172
175
|
return dyn_binop_grad_common_with_shift(x, y, dx, dy, shift)
|
|
173
176
|
|
|
174
177
|
|
|
175
|
-
def _dyn_reduced_shape(input_shape, axis):
|
|
178
|
+
def _dyn_reduced_shape(input_shape, axis, x):
|
|
176
179
|
"""Dynamic reduce shape"""
|
|
177
180
|
input_shape = P.Cast()(input_shape, ms.int32)
|
|
178
|
-
if
|
|
179
|
-
|
|
180
|
-
expanded_axis = P.ExpandDims()(axis, 1)
|
|
181
|
-
update = P.Cast()(P.OnesLike()(axis), ms.int32)
|
|
182
|
-
return P.TensorScatterUpdate()(input_shape, expanded_axis, update)
|
|
183
|
-
input_rank = P.Rank()(input_shape)
|
|
184
|
-
real_axis = (axis + input_rank) % input_rank
|
|
185
|
-
axis_shape = shape_op(real_axis)
|
|
181
|
+
if x is not None and not F.is_sequence_shape_unknown(shape_op(x)):
|
|
182
|
+
input_rank = len(shape_op(x))
|
|
186
183
|
else:
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
184
|
+
input_rank = dyn_rank(x)
|
|
185
|
+
input_rank = P.Cast()(input_rank, ms.int32)
|
|
186
|
+
|
|
187
|
+
if (isinstance(axis, tuple) and axis == ()) or (isinstance(axis, list) and axis == []):
|
|
188
|
+
res_shape = P.ExpandDims()(input_rank, 0)
|
|
189
|
+
return dyn_ones(res_shape, res_shape.dtype)
|
|
190
|
+
|
|
191
|
+
if isinstance(axis, int):
|
|
192
|
+
axis = (axis,)
|
|
193
|
+
|
|
194
|
+
real_axis = axis
|
|
195
|
+
if not isinstance(axis, Tensor):
|
|
196
|
+
real_axis = Tensor(axis, ms.int32)
|
|
197
|
+
|
|
198
|
+
real_axis = (real_axis + input_rank) % input_rank
|
|
199
|
+
if real_axis.ndim == 0:
|
|
200
|
+
real_axis = P.ExpandDims()(real_axis, 0)
|
|
201
|
+
expanded_axis = P.ExpandDims()(real_axis, 1)
|
|
202
|
+
expanded_axis = P.Cast()(expanded_axis, ms.int32)
|
|
203
|
+
update = P.Cast()(P.OnesLike()(real_axis), ms.float32)
|
|
204
|
+
input_shape = P.Cast()(input_shape, ms.float32)
|
|
205
|
+
return P.TensorScatterUpdate()(input_shape, expanded_axis, update)
|
|
198
206
|
|
|
199
207
|
|
|
200
208
|
def _sum_grad(x, axis, dout):
|
|
201
209
|
"""Grad definition for `Sum` operation."""
|
|
202
210
|
input_shape = shape_op(x)
|
|
203
211
|
is_mutable, axis = convert_to_tensor(axis)
|
|
204
|
-
if
|
|
212
|
+
if F.is_sequence_value_unknown(input_shape) or is_mutable:
|
|
205
213
|
input_shape = dyn_shape_op(x)
|
|
206
|
-
output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis)
|
|
214
|
+
output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
|
|
215
|
+
output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int32)
|
|
207
216
|
grad = reshape(dout, output_shape_kept_dims)
|
|
208
217
|
return DynamicBroadcastTo()(grad, input_shape)
|
|
209
218
|
|
|
@@ -216,15 +225,40 @@ def _sum_grad(x, axis, dout):
|
|
|
216
225
|
def _min_or_max_grad(x, axis, out, dout):
|
|
217
226
|
"""Grad definition for `Min` and `Max` operations."""
|
|
218
227
|
input_shape = shape_op(x)
|
|
219
|
-
output_shape_kept_dims =
|
|
228
|
+
output_shape_kept_dims = ()
|
|
229
|
+
if F.is_sequence_value_unknown(input_shape):
|
|
230
|
+
input_shape = dyn_shape_op(x)
|
|
231
|
+
output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
|
|
232
|
+
output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int32)
|
|
233
|
+
else:
|
|
234
|
+
output_shape_kept_dims = reduced_shape(input_shape, axis)
|
|
235
|
+
|
|
220
236
|
y = reshape(out, output_shape_kept_dims)
|
|
221
237
|
grad = reshape(dout, output_shape_kept_dims)
|
|
222
238
|
indicators = F.cast(F.equal(y, x), F.dtype(grad))
|
|
223
|
-
min_num = F.cast(F.
|
|
239
|
+
min_num = F.cast(F.scalar_to_tensor(1e-24), F.dtype(grad))
|
|
224
240
|
num_selected = reshape(reduce_sum(indicators, axis), output_shape_kept_dims) + min_num
|
|
225
241
|
return indicators / num_selected * grad
|
|
226
242
|
|
|
227
243
|
|
|
244
|
+
def _onehot_with_neg_axis(axis, indices, depth, on_value_dtype):
|
|
245
|
+
"""onehot support tensor axis"""
|
|
246
|
+
depth_range = P.Range()(F.cast(0, depth.dtype), depth, F.cast(1, depth.dtype))
|
|
247
|
+
indices_expand = P.ExpandDims()(indices, axis)
|
|
248
|
+
indices_expand_rank = dyn_rank_1d(indices_expand)
|
|
249
|
+
broad_shape = dyn_ones(indices_expand_rank, mstype.int64)
|
|
250
|
+
# It should use int64 dtype, but the TensorScatterUpdate op does not support the int64
|
|
251
|
+
# dtype on Ascend device, so the float32 dtype is used here.
|
|
252
|
+
update_dtype = mstype.float32
|
|
253
|
+
broad_shape = dyn_ones(indices_expand_rank, update_dtype)
|
|
254
|
+
broad_shape[axis] = F.cast(depth, update_dtype)
|
|
255
|
+
broad_shape = F.cast(broad_shape, mstype.int64)
|
|
256
|
+
depth_broad = P.Reshape()(depth_range, broad_shape)
|
|
257
|
+
one_hot_bool = P.Equal()(indices_expand, depth_broad)
|
|
258
|
+
one_hot_res = F.cast(one_hot_bool, on_value_dtype)
|
|
259
|
+
return one_hot_res
|
|
260
|
+
|
|
261
|
+
|
|
228
262
|
def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
|
|
229
263
|
"""ArgMinWiwhValue and ArgMaxWithValue grad."""
|
|
230
264
|
expand = P.ExpandDims()
|
|
@@ -232,53 +266,48 @@ def _argmin_or_argmax_grad(x, axis, keep_dims, op, out, dout):
|
|
|
232
266
|
x_shape = F.shape(x)
|
|
233
267
|
x_dim = len(x_shape)
|
|
234
268
|
x_axis = axis
|
|
269
|
+
onehot_axis_is_neg = False
|
|
235
270
|
if x_axis < 0:
|
|
236
|
-
|
|
271
|
+
if not F.is_sequence_shape_unknown(x_shape):
|
|
272
|
+
x_axis = axis + x_dim
|
|
273
|
+
else:
|
|
274
|
+
onehot_axis_is_neg = True
|
|
237
275
|
onehot_axis = x_axis
|
|
238
|
-
depth = 1
|
|
239
|
-
if x_shape:
|
|
240
|
-
depth = x_shape[axis]
|
|
241
276
|
if keep_dims:
|
|
242
277
|
dout_expand = dout[1]
|
|
243
278
|
out = op(x)
|
|
244
279
|
else:
|
|
245
280
|
dout_expand = expand(dout[1], onehot_axis)
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
281
|
+
out_shape = shape_op(out[0])
|
|
282
|
+
if not F.is_sequence_shape_unknown(out_shape):
|
|
283
|
+
if onehot_axis >= len(out_shape):
|
|
284
|
+
onehot_axis = -1
|
|
249
285
|
type_x = F.dtype(x)
|
|
250
|
-
on_value = F.cast(F.
|
|
251
|
-
off_value = F.cast(F.
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
286
|
+
on_value = F.cast(F.scalar_to_tensor(1.0), type_x)
|
|
287
|
+
off_value = F.cast(F.scalar_to_tensor(0.0), type_x)
|
|
288
|
+
if not F.is_sequence_value_unknown(x_shape):
|
|
289
|
+
depth = 1
|
|
290
|
+
if x_shape:
|
|
291
|
+
depth = x_shape[axis]
|
|
292
|
+
onehot = P.OneHot(onehot_axis)
|
|
293
|
+
dx = dout_expand * onehot(out[0], depth, on_value, off_value)
|
|
294
|
+
if not x_shape:
|
|
295
|
+
dx = squeeze(dx)
|
|
296
|
+
return dx
|
|
297
|
+
x_tensor_shape = P.TensorShape()(x)
|
|
298
|
+
depth = x_tensor_shape[axis]
|
|
299
|
+
if not onehot_axis_is_neg:
|
|
300
|
+
onehot = P.OneHot(onehot_axis)
|
|
301
|
+
dx = dout_expand * onehot(out[0], depth, on_value, off_value)
|
|
302
|
+
else:
|
|
303
|
+
if out[0].value is not None:
|
|
304
|
+
# It is a temporary method: In the pynative mode, out may be a constant tensor. Constant
|
|
305
|
+
# folding occurs in ExpandDims op, but such scenarios are not supported currently.
|
|
306
|
+
out = op(x)
|
|
307
|
+
dx = dout_expand * _onehot_with_neg_axis(onehot_axis, out[0], depth, on_value.dtype)
|
|
255
308
|
return dx
|
|
256
309
|
|
|
257
310
|
|
|
258
|
-
@bprop_getters.register(P.MatMul)
|
|
259
|
-
def bprop_matmul(self):
|
|
260
|
-
"""Grad definition for `MatMul` operation."""
|
|
261
|
-
ta = self.transpose_a
|
|
262
|
-
tb = self.transpose_b
|
|
263
|
-
mul1 = P.MatMul(transpose_a=(ta and tb),
|
|
264
|
-
transpose_b=(ta or (not tb)))
|
|
265
|
-
mul2 = P.MatMul(transpose_a=((not ta) or tb),
|
|
266
|
-
transpose_b=(ta and tb))
|
|
267
|
-
|
|
268
|
-
def bprop(x, w, out, dout):
|
|
269
|
-
if ta:
|
|
270
|
-
dx = mul1(w, dout)
|
|
271
|
-
else:
|
|
272
|
-
dx = mul1(dout, w)
|
|
273
|
-
if tb:
|
|
274
|
-
dw = mul2(dout, x)
|
|
275
|
-
else:
|
|
276
|
-
dw = mul2(x, dout)
|
|
277
|
-
return dx, dw
|
|
278
|
-
|
|
279
|
-
return bprop
|
|
280
|
-
|
|
281
|
-
|
|
282
311
|
@bprop_getters.register(P.BatchMatMul)
|
|
283
312
|
def bprop_batchmatmul(self):
|
|
284
313
|
"""Grad definition for `BatchMatMul` operation."""
|
|
@@ -303,16 +332,6 @@ def bprop_batchmatmul(self):
|
|
|
303
332
|
return bprop
|
|
304
333
|
|
|
305
334
|
|
|
306
|
-
@bprop_getters.register(P.Add)
|
|
307
|
-
def get_bprop_add(self):
|
|
308
|
-
"""Grad definition for `Add` operation."""
|
|
309
|
-
|
|
310
|
-
def bprop(x, y, out, dout):
|
|
311
|
-
return binop_grad_common(x, y, dout, dout)
|
|
312
|
-
|
|
313
|
-
return bprop
|
|
314
|
-
|
|
315
|
-
|
|
316
335
|
@bprop_getters.register(P.TensorAdd)
|
|
317
336
|
def get_bprop_tensor_add(self):
|
|
318
337
|
"""Grad definition for `Add` operation."""
|
|
@@ -339,35 +358,14 @@ def get_bprop_matrix_inverse(self):
|
|
|
339
358
|
return bprop
|
|
340
359
|
|
|
341
360
|
|
|
342
|
-
@bprop_getters.register(P.Neg)
|
|
343
|
-
def get_bprop_neg(self):
|
|
344
|
-
"""Grad definition for `Neg` operation."""
|
|
345
|
-
neg_grad = P.Neg()
|
|
346
|
-
|
|
347
|
-
def bprop(x, out, dout):
|
|
348
|
-
dx = neg_grad(dout)
|
|
349
|
-
return (dx,)
|
|
350
|
-
|
|
351
|
-
return bprop
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
@bprop_getters.register(P.Sub)
|
|
355
|
-
def get_bprop_sub(self):
|
|
356
|
-
"""Grad definition for `Sub` operation."""
|
|
357
|
-
neg_func = P.Neg()
|
|
358
|
-
|
|
359
|
-
def bprop(x, y, out, dout):
|
|
360
|
-
return binop_grad_common(x, y, dout, neg_func(dout))
|
|
361
|
-
|
|
362
|
-
return bprop
|
|
363
|
-
|
|
364
|
-
|
|
365
361
|
@bprop_getters.register(P.Mul)
|
|
366
362
|
def get_bprop_mul(self):
|
|
367
363
|
"""Grad definition for `Mul` operation."""
|
|
368
364
|
mul_func = P.Mul()
|
|
369
365
|
|
|
370
366
|
def bprop(x, y, out, dout):
|
|
367
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
368
|
+
raise TypeError("For 'Mul', gradient not support for complex type currently.")
|
|
371
369
|
bc_dx = mul_func(y, dout)
|
|
372
370
|
bc_dy = mul_func(x, dout)
|
|
373
371
|
return binop_grad_common(x, y, bc_dx, bc_dy)
|
|
@@ -383,6 +381,8 @@ def get_bprop_real_div(self):
|
|
|
383
381
|
mul_op = P.Mul()
|
|
384
382
|
|
|
385
383
|
def bprop(x, y, out, dout):
|
|
384
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
385
|
+
raise TypeError("For 'RealDiv', gradient not support for complex type currently.")
|
|
386
386
|
bc_x = div_op(dout, y)
|
|
387
387
|
bc_y = neg(mul_op(bc_x, out))
|
|
388
388
|
return binop_grad_common(x, y, bc_x, bc_y)
|
|
@@ -443,7 +443,10 @@ def get_bprop_floor(self):
|
|
|
443
443
|
dtype_ = P.DType()
|
|
444
444
|
|
|
445
445
|
def bprop(x, out, dout):
|
|
446
|
-
|
|
446
|
+
if F.is_sequence_value_unknown(shape_(x)):
|
|
447
|
+
bc_x = zeros_like(x)
|
|
448
|
+
else:
|
|
449
|
+
bc_x = fill_(dtype_(x), shape_(x), 0.)
|
|
447
450
|
return (bc_x,)
|
|
448
451
|
|
|
449
452
|
return bprop
|
|
@@ -457,7 +460,10 @@ def get_bprop_ceil(self):
|
|
|
457
460
|
dtype_ = P.DType()
|
|
458
461
|
|
|
459
462
|
def bprop(x, out, dout):
|
|
460
|
-
|
|
463
|
+
if F.is_sequence_value_unknown(shape_(x)):
|
|
464
|
+
bc_x = zeros_like(x)
|
|
465
|
+
else:
|
|
466
|
+
bc_x = fill_(dtype_(x), shape_(x), 0.)
|
|
461
467
|
return (bc_x,)
|
|
462
468
|
|
|
463
469
|
return bprop
|
|
@@ -473,6 +479,36 @@ def get_bprop_floordiv(self):
|
|
|
473
479
|
return bprop
|
|
474
480
|
|
|
475
481
|
|
|
482
|
+
@bprop_getters.register(P.BitwiseAnd)
|
|
483
|
+
def get_bprop_bitwiseand(self):
|
|
484
|
+
"""Grad definition for `BitwiseAnd` operation."""
|
|
485
|
+
|
|
486
|
+
def bprop(x, y, out, dout):
|
|
487
|
+
return zeros_like(x), zeros_like(y)
|
|
488
|
+
|
|
489
|
+
return bprop
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
@bprop_getters.register(P.BitwiseOr)
|
|
493
|
+
def get_bprop_bitwiseor(self):
|
|
494
|
+
"""Grad definition for `BitwiseOr` operation."""
|
|
495
|
+
|
|
496
|
+
def bprop(x, y, out, dout):
|
|
497
|
+
return zeros_like(x), zeros_like(y)
|
|
498
|
+
|
|
499
|
+
return bprop
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
@bprop_getters.register(P.BitwiseXor)
|
|
503
|
+
def get_bprop_bitwisexor(self):
|
|
504
|
+
"""Grad definition for `BitwiseXor` operation."""
|
|
505
|
+
|
|
506
|
+
def bprop(x, y, out, dout):
|
|
507
|
+
return zeros_like(x), zeros_like(y)
|
|
508
|
+
|
|
509
|
+
return bprop
|
|
510
|
+
|
|
511
|
+
|
|
476
512
|
@bprop_getters.register(P.FloorMod)
|
|
477
513
|
def get_bprop_floormod(self):
|
|
478
514
|
"""Grad definition for `FloorMod` operation."""
|
|
@@ -529,7 +565,12 @@ def get_bprop_square(self):
|
|
|
529
565
|
|
|
530
566
|
def bprop(x, out, dout):
|
|
531
567
|
temp = mul_func(dout, x)
|
|
532
|
-
|
|
568
|
+
shape_x = shape_op(x)
|
|
569
|
+
if F.is_sequence_value_unknown(shape_x):
|
|
570
|
+
fill_value = dyn_fill(dtype(temp), dyn_shape_op(x), 2.0)
|
|
571
|
+
else:
|
|
572
|
+
fill_value = fill_func(dtype(temp), shape_x, 2.0)
|
|
573
|
+
dx = mul_func(fill_value, temp)
|
|
533
574
|
return (dx,)
|
|
534
575
|
|
|
535
576
|
return bprop
|
|
@@ -575,8 +616,15 @@ def get_bprop_square_sum_all(self):
|
|
|
575
616
|
def bprop(x, y, out, dout):
|
|
576
617
|
temp_x = mul_func(dout[0], x)
|
|
577
618
|
temp_y = mul_func(dout[1], y)
|
|
578
|
-
|
|
579
|
-
|
|
619
|
+
if F.is_sequence_value_unknown(shape_op(x)):
|
|
620
|
+
dx = mul_func(dyn_fill(dtype(temp_x), dyn_shape_op(x), 2.0), temp_x)
|
|
621
|
+
else:
|
|
622
|
+
dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
|
|
623
|
+
|
|
624
|
+
if F.is_sequence_value_unknown(shape_op(y)):
|
|
625
|
+
dy = mul_func(dyn_fill(dtype(temp_y), dyn_shape_op(y), 2.0), temp_y)
|
|
626
|
+
else:
|
|
627
|
+
dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
|
|
580
628
|
return (dx, dy)
|
|
581
629
|
|
|
582
630
|
return bprop
|
|
@@ -716,8 +764,14 @@ def get_bprop_pow(self):
|
|
|
716
764
|
ln = P.Log()
|
|
717
765
|
|
|
718
766
|
def bprop(x, power, out, dout):
|
|
767
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
768
|
+
raise TypeError("For 'Pow', gradient not support for complex type currently.")
|
|
719
769
|
bc_dx = power * pow_op(x, power - 1.0) * dout
|
|
720
|
-
|
|
770
|
+
shape_x = shape_op(x)
|
|
771
|
+
if F.is_sequence_value_unknown(shape_x):
|
|
772
|
+
x = F.select(x < 0, dyn_fill(F.dtype(x), dyn_shape_op(x), 1), x)
|
|
773
|
+
else:
|
|
774
|
+
x = F.select(x < 0, F.fill(F.dtype(x), F.shape(x), 1), x)
|
|
721
775
|
bc_dpower = out * ln(x) * dout
|
|
722
776
|
return binop_grad_common(x, power, bc_dx, bc_dpower)
|
|
723
777
|
|
|
@@ -808,21 +862,31 @@ def get_bprop_cumsum(self):
|
|
|
808
862
|
return bprop
|
|
809
863
|
|
|
810
864
|
|
|
811
|
-
@
|
|
865
|
+
@_primexpr
|
|
812
866
|
def _split_shape_index(input_shape, axis):
|
|
813
867
|
"""Calculate reduce_prod grad transpose indices and perm shape."""
|
|
814
868
|
rank = len(input_shape)
|
|
815
869
|
if isinstance(axis, int):
|
|
816
870
|
axis = tuple([axis])
|
|
817
871
|
reduction_indices = tuple([(i + rank) % rank for i in axis])
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
872
|
+
other_indices_list = []
|
|
873
|
+
for i in range(rank):
|
|
874
|
+
if i not in reduction_indices and i not in other_indices_list:
|
|
875
|
+
other_indices_list.append(i)
|
|
876
|
+
other_indices = tuple(other_indices_list)
|
|
877
|
+
reduced_list = [1] + [input_shape[i] for i in reduction_indices]
|
|
878
|
+
other_list = [1] + [input_shape[i] for i in other_indices]
|
|
879
|
+
reduced_num = 1
|
|
880
|
+
for i in reduced_list:
|
|
881
|
+
reduced_num = reduced_num * i
|
|
882
|
+
other_num = 1
|
|
883
|
+
for i in other_list:
|
|
884
|
+
other_num = other_num * i
|
|
821
885
|
perm = reduction_indices + other_indices
|
|
822
886
|
return tuple([reduced_num, other_num]), perm
|
|
823
887
|
|
|
824
888
|
|
|
825
|
-
@
|
|
889
|
+
@_primexpr
|
|
826
890
|
def _invert_permutation(perm):
|
|
827
891
|
"""Calculate invert permutation."""
|
|
828
892
|
out = [0] * len(perm)
|
|
@@ -831,6 +895,26 @@ def _invert_permutation(perm):
|
|
|
831
895
|
return tuple(out)
|
|
832
896
|
|
|
833
897
|
|
|
898
|
+
def _split_dyn_shape_index(x, axis):
|
|
899
|
+
"""Calculate reduce prod grad invert permutation."""
|
|
900
|
+
input_shape = dyn_shape_op(x)
|
|
901
|
+
rank = dyn_rank(x)
|
|
902
|
+
if not isinstance(axis, Tensor):
|
|
903
|
+
axis = Tensor(axis, dtype=mstype.int64)
|
|
904
|
+
reduction_indices = reshape(axis, (-1,))
|
|
905
|
+
reduction_indices = (reduction_indices + rank) % rank
|
|
906
|
+
reduced = P.Cast()(reduction_indices, mstype.int64)
|
|
907
|
+
|
|
908
|
+
start = Tensor(0, dtype=mstype.int64)
|
|
909
|
+
delta = Tensor(1, dtype=mstype.int64)
|
|
910
|
+
idx = P.Range()(start, rank, delta)
|
|
911
|
+
other, _ = A.ListDiff()(idx, reduced)
|
|
912
|
+
perm = P.Concat()((reduced, other))
|
|
913
|
+
reduced_num = reduce_prod(P.Cast()(P.Gather()(input_shape, reduced, 0), mstype.int64), ())
|
|
914
|
+
other_num = reduce_prod(P.Cast()(P.Gather()(input_shape, other, 0), mstype.int64), ())
|
|
915
|
+
return (reduced_num, other_num), perm
|
|
916
|
+
|
|
917
|
+
|
|
834
918
|
@bprop_getters.register(P.ReduceProd)
|
|
835
919
|
def get_bprop_reduceprod(self):
|
|
836
920
|
"""Grad definition for `ReduceProd` operation."""
|
|
@@ -840,17 +924,35 @@ def get_bprop_reduceprod(self):
|
|
|
840
924
|
|
|
841
925
|
def bprop(x, axis, out, dout):
|
|
842
926
|
"""Grad definition for `Product` operation."""
|
|
927
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
928
|
+
raise TypeError("The 'ReduceProd', gradient not support for complex type currently.")
|
|
843
929
|
# Expand dout to full input shape
|
|
844
930
|
input_shape = shape_op(x)
|
|
845
|
-
|
|
931
|
+
if input_shape == ():
|
|
932
|
+
dx = _sum_grad(x, axis, dout)
|
|
933
|
+
return dx, zeros_like(axis)
|
|
934
|
+
|
|
935
|
+
if F.is_sequence_value_unknown(input_shape):
|
|
936
|
+
input_shape = dyn_shape_op(x)
|
|
937
|
+
input_shape = P.Cast()(input_shape, ms.int64)
|
|
938
|
+
output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
|
|
939
|
+
output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int64)
|
|
940
|
+
else:
|
|
941
|
+
output_shape_kept_dims = reduced_shape(input_shape, axis)
|
|
942
|
+
|
|
846
943
|
dout = reshape(dout, output_shape_kept_dims)
|
|
847
|
-
tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
|
|
848
|
-
grad = tile(dout, tile_scaling)
|
|
849
944
|
|
|
850
945
|
# Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
|
|
851
|
-
|
|
946
|
+
if F.is_sequence_value_unknown(shape_op(x)):
|
|
947
|
+
pack_shape, perm = _split_dyn_shape_index(x, axis)
|
|
948
|
+
else:
|
|
949
|
+
pack_shape, perm = _split_shape_index(shape_op(x), axis)
|
|
950
|
+
|
|
852
951
|
permuted = transpose(x, perm)
|
|
853
952
|
permuted_shape = shape_op(permuted)
|
|
953
|
+
if F.is_sequence_value_unknown(permuted_shape):
|
|
954
|
+
permuted_shape = dyn_shape_op(permuted)
|
|
955
|
+
pack_shape = create_tensor_by_element(pack_shape)
|
|
854
956
|
reshaped = reshape(permuted, pack_shape)
|
|
855
957
|
|
|
856
958
|
# Calculate product, leaving out the current entry
|
|
@@ -860,7 +962,14 @@ def get_bprop_reduceprod(self):
|
|
|
860
962
|
|
|
861
963
|
# Invert the transpose and reshape operations.
|
|
862
964
|
# Make sure to set the statically known shape information through a reshape.
|
|
863
|
-
|
|
965
|
+
if F.is_sequence_value_unknown(shape_op(permuted)):
|
|
966
|
+
dout = DynamicBroadcastTo()(dout, input_shape)
|
|
967
|
+
out = transpose(y, dyn_invert_permutation(perm)) * dout
|
|
968
|
+
else:
|
|
969
|
+
tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
|
|
970
|
+
grad = tile(dout, tile_scaling)
|
|
971
|
+
out = transpose(y, _invert_permutation(perm)) * grad
|
|
972
|
+
|
|
864
973
|
dx = reshape(out, input_shape)
|
|
865
974
|
return dx, zeros_like(axis)
|
|
866
975
|
|
|
@@ -908,6 +1017,8 @@ def get_bprop_reducemax(self):
|
|
|
908
1017
|
"""Grad definition for `Max` operation."""
|
|
909
1018
|
|
|
910
1019
|
def bprop(x, axis, out, dout):
|
|
1020
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1021
|
+
raise TypeError("The 'ReduceMax', gradient not support for complex type currently.")
|
|
911
1022
|
dx = _min_or_max_grad(x, axis, out, dout)
|
|
912
1023
|
return (dx, zeros_like(axis))
|
|
913
1024
|
|
|
@@ -933,6 +1044,8 @@ def get_bprop_reducemin(self):
|
|
|
933
1044
|
"""Grad definition for `ReduceMin` operation."""
|
|
934
1045
|
|
|
935
1046
|
def bprop(x, axis, out, dout):
|
|
1047
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1048
|
+
raise TypeError("The 'ReduceMin', gradient not support for complex type currently.")
|
|
936
1049
|
dx = _min_or_max_grad(x, axis, out, dout)
|
|
937
1050
|
return (dx, zeros_like(axis))
|
|
938
1051
|
|
|
@@ -961,17 +1074,20 @@ def get_bprop_reduce_mean(self):
|
|
|
961
1074
|
dtype = P.DType()
|
|
962
1075
|
|
|
963
1076
|
def bprop(x, axis, out, dout):
|
|
1077
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1078
|
+
raise TypeError("The 'ReduceMean', gradient not support for complex type currently.")
|
|
964
1079
|
grad = _sum_grad(x, axis, dout)
|
|
965
1080
|
shape_x = shape_op(x)
|
|
966
1081
|
shape_out = shape_op(out)
|
|
967
|
-
if
|
|
1082
|
+
if F.is_sequence_value_unknown(shape_x) or F.is_sequence_value_unknown(shape_out):
|
|
968
1083
|
shape_x = dyn_shape_op(x)
|
|
969
1084
|
shape_out = dyn_shape_op(out)
|
|
970
|
-
div_shape = reduce_prod(shape_x)
|
|
1085
|
+
div_shape = reduce_prod(cast(shape_x, mstype.float32), ()) /\
|
|
1086
|
+
reduce_prod(cast(shape_out, mstype.float32), ())
|
|
971
1087
|
dx = div_op(grad, cast(div_shape, dtype(grad)))
|
|
972
1088
|
else:
|
|
973
1089
|
div_shape = F.shape_mul(shape_x) / F.shape_mul(shape_out)
|
|
974
|
-
dx = div_op(grad, cast(F.
|
|
1090
|
+
dx = div_op(grad, cast(F.scalar_to_tensor(div_shape), dtype(grad)))
|
|
975
1091
|
return dx, zeros_like(axis)
|
|
976
1092
|
|
|
977
1093
|
return bprop
|
|
@@ -1097,16 +1213,6 @@ def get_bprop_logical_and(self):
|
|
|
1097
1213
|
return bprop
|
|
1098
1214
|
|
|
1099
1215
|
|
|
1100
|
-
@bprop_getters.register(P.LogicalOr)
|
|
1101
|
-
def get_bprop_logical_or(self):
|
|
1102
|
-
"""Grad definition for `LogicalOr` operation."""
|
|
1103
|
-
|
|
1104
|
-
def bprop(x, y, out, dout):
|
|
1105
|
-
return zeros_like(x), zeros_like(y)
|
|
1106
|
-
|
|
1107
|
-
return bprop
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
1216
|
@bprop_getters.register(P.NPUAllocFloatStatus)
|
|
1111
1217
|
def get_bprop_npu_alloc_float_status(self):
|
|
1112
1218
|
"""Grad definition for `NPUAllocFloatStatus` operation."""
|
|
@@ -1304,6 +1410,9 @@ def get_bprop_cosh(self):
|
|
|
1304
1410
|
sinh = P.Sinh()
|
|
1305
1411
|
|
|
1306
1412
|
def bprop(x, out, dout):
|
|
1413
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1414
|
+
raise TypeError("The 'Cosh', gradient not support for complex type currently.")
|
|
1415
|
+
|
|
1307
1416
|
dx = sinh(x) * dout
|
|
1308
1417
|
return (dx,)
|
|
1309
1418
|
|
|
@@ -1334,16 +1443,6 @@ def get_bprop_conj(self):
|
|
|
1334
1443
|
return bprop
|
|
1335
1444
|
|
|
1336
1445
|
|
|
1337
|
-
@bprop_getters.register(P.ScalarCast)
|
|
1338
|
-
def get_bprop_scalar_cast(self):
|
|
1339
|
-
"""Generate bprop for ScalarCast"""
|
|
1340
|
-
|
|
1341
|
-
def bprop(x, t, out, dout):
|
|
1342
|
-
return F.scalar_cast(dout, F.typeof(x)), zeros_like(t)
|
|
1343
|
-
|
|
1344
|
-
return bprop
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
1446
|
@bprop_getters.register(P.AccumulateNV2)
|
|
1348
1447
|
def get_bprop_scalar_accumulatenv2(self):
|
|
1349
1448
|
"""Generate bprop for AccumulateNV2"""
|
|
@@ -1457,6 +1556,9 @@ def get_bprop_tan(self):
|
|
|
1457
1556
|
cos = P.Cos()
|
|
1458
1557
|
|
|
1459
1558
|
def bprop(x, out, dout):
|
|
1559
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1560
|
+
raise TypeError("For 'Tan', gradient not support for complex type currently.")
|
|
1561
|
+
|
|
1460
1562
|
cosx = cos(x)
|
|
1461
1563
|
secx2 = square(reciprocal(cosx))
|
|
1462
1564
|
dx = secx2 * dout
|
|
@@ -1498,6 +1600,9 @@ def get_bprop_atanh(self):
|
|
|
1498
1600
|
div = P.Div()
|
|
1499
1601
|
|
|
1500
1602
|
def bprop(x, out, dout):
|
|
1603
|
+
if x.dtype in (mstype.complex64, mstype.complex128):
|
|
1604
|
+
raise TypeError("For 'Atanh', gradient not support for complex type currently.")
|
|
1605
|
+
|
|
1501
1606
|
tmp = 1 - power(x, 2)
|
|
1502
1607
|
dx = div(1, tmp) * dout
|
|
1503
1608
|
return (dx,)
|
|
@@ -1537,3 +1642,43 @@ def get_bprop_index_add(self):
|
|
|
1537
1642
|
return dout, zeros_like(indices), gather(dout, indices, _axis)
|
|
1538
1643
|
|
|
1539
1644
|
return bprop
|
|
1645
|
+
|
|
1646
|
+
|
|
1647
|
+
@bprop_getters.register(P.InplaceUpdate)
|
|
1648
|
+
def get_bprop_inplace_update(self):
|
|
1649
|
+
"""Grad definition for `InplaceUpdate` operation."""
|
|
1650
|
+
|
|
1651
|
+
def bprop(x, v, out, dout):
|
|
1652
|
+
return zeros_like(x), zeros_like(v)
|
|
1653
|
+
|
|
1654
|
+
return bprop
|
|
1655
|
+
|
|
1656
|
+
|
|
1657
|
+
@bprop_getters.register(P.InplaceUpdateV2)
|
|
1658
|
+
def get_bprop_inplace_update_v2(self):
|
|
1659
|
+
"""Grad definition for `InplaceUpdateV2` operation."""
|
|
1660
|
+
|
|
1661
|
+
def bprop(x, indices, v, out, dout):
|
|
1662
|
+
return zeros_like(x), zeros_like(indices), zeros_like(v)
|
|
1663
|
+
|
|
1664
|
+
return bprop
|
|
1665
|
+
|
|
1666
|
+
|
|
1667
|
+
@bprop_getters.register(P.InplaceSub)
|
|
1668
|
+
def get_bprop_inplace_sub(self):
|
|
1669
|
+
"""Grad definition for `InplaceSub` operation."""
|
|
1670
|
+
|
|
1671
|
+
def bprop(x, input_v, out, dout):
|
|
1672
|
+
return zeros_like(x), zeros_like(input_v)
|
|
1673
|
+
|
|
1674
|
+
return bprop
|
|
1675
|
+
|
|
1676
|
+
|
|
1677
|
+
@bprop_getters.register(P.InplaceAdd)
|
|
1678
|
+
def get_bprop_inplace_add(self):
|
|
1679
|
+
"""Grad definition for `InplaceAdd` operation."""
|
|
1680
|
+
|
|
1681
|
+
def bprop(x, input_v, out, dout):
|
|
1682
|
+
return zeros_like(x), zeros_like(input_v)
|
|
1683
|
+
|
|
1684
|
+
return bprop
|