mindspore 1.10.0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -14,16 +14,17 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
|
|
16
16
|
"""The vmap implement of grad operator corresponding to nn_ops."""
|
|
17
|
-
|
|
18
17
|
from __future__ import absolute_import
|
|
18
|
+
|
|
19
19
|
from __future__ import division
|
|
20
20
|
from functools import reduce
|
|
21
21
|
import mindspore.numpy as mnp
|
|
22
22
|
from mindspore.ops.operations import _grad_ops as G
|
|
23
23
|
from mindspore.ops import functional as F
|
|
24
24
|
from mindspore.ops import constexpr
|
|
25
|
+
from mindspore.ops.primitive import _primexpr
|
|
25
26
|
from mindspore.ops.primitive import Primitive
|
|
26
|
-
from mindspore.ops.
|
|
27
|
+
from mindspore.ops.function import _VmapGeneralRule
|
|
27
28
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
|
|
28
29
|
_bdim_at_front, _vmap_clone_prim, _vmap_update_prim_attr, _bdim_at_any, _handle_broadcasting
|
|
29
30
|
|
|
@@ -38,7 +39,7 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
|
|
|
38
39
|
2. And weight only support shape as (C,), while total_weight should be a scalar.
|
|
39
40
|
"""
|
|
40
41
|
|
|
41
|
-
@
|
|
42
|
+
@_primexpr
|
|
42
43
|
def _get_reshape_shape(shape, keep_dim=0):
|
|
43
44
|
new_batch_size = reduce(
|
|
44
45
|
lambda x, y: x * y, shape if keep_dim == 0 else shape[:-keep_dim])
|
|
@@ -104,6 +105,7 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
|
|
|
104
105
|
return vmap_rule
|
|
105
106
|
|
|
106
107
|
|
|
108
|
+
@vmap_rules_getters.register(G.MaxPoolGrad)
|
|
107
109
|
@vmap_rules_getters.register(G.AvgPoolGrad)
|
|
108
110
|
def get_avg_pool_grad_vmap_rule(prim, axis_size):
|
|
109
111
|
"""VmapRule for `AvgPoolGrad`."""
|
|
@@ -225,11 +227,15 @@ def get_cdist_grad_vmap_rule(prim, axis_size):
|
|
|
225
227
|
return vmap_rule
|
|
226
228
|
|
|
227
229
|
|
|
230
|
+
@vmap_rules_getters.register(G.AdaptiveMaxPool3DGrad)
|
|
228
231
|
@vmap_rules_getters.register(G.AdaptiveMaxPool2DGrad)
|
|
229
232
|
def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
|
|
230
|
-
"""VmapRule for `AdaptiveMaxPool2DGrad` operation."""
|
|
233
|
+
"""VmapRule for `AdaptiveMaxPool2DGrad` and `AdaptiveMaxPool3DGrad` operation."""
|
|
231
234
|
chw_reverse_index = -3
|
|
232
|
-
|
|
235
|
+
if prim.name == "AdaptiveMaxPool2DGrad":
|
|
236
|
+
hw_reverse_index = -2
|
|
237
|
+
else:
|
|
238
|
+
hw_reverse_index = -3
|
|
233
239
|
|
|
234
240
|
def vmap_rule(ygrad_bdim, x_bdim, max_index_bdim):
|
|
235
241
|
is_all_none, result = vmap_general_preprocess(prim, ygrad_bdim, x_bdim, max_index_bdim)
|
|
@@ -352,7 +358,7 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
|
|
|
352
358
|
if is_all_none:
|
|
353
359
|
return result
|
|
354
360
|
if data_format == "NHWC":
|
|
355
|
-
#BatchNormGrad with NHWC format is a GPU backend operation and not supported for now.
|
|
361
|
+
# BatchNormGrad with NHWC format is a GPU backend operation and not supported for now.
|
|
356
362
|
return batchnorm_grad_nhwc_vmap(grad_bdim, x_bdim, scale_bdim, rsv_1_bdim, rsv_2_bdim, rsv_3_bdim)
|
|
357
363
|
grad, grad_dim = grad_bdim
|
|
358
364
|
input_x, input_x_dim = x_bdim
|
|
@@ -392,8 +398,9 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
|
|
|
392
398
|
|
|
393
399
|
@vmap_rules_getters.register(G.MaxPoolGradGrad)
|
|
394
400
|
@vmap_rules_getters.register(G.MaxPoolGradGradWithArgmax)
|
|
401
|
+
@vmap_rules_getters.register(G.MaxPoolGradWithArgmaxV2)
|
|
395
402
|
def get_maxpool_grad_grad_vmap_rule(prim, axis_size):
|
|
396
|
-
"""VmapRule for `MaxPoolGradGrad` and `
|
|
403
|
+
"""VmapRule for `MaxPoolGradGrad`, `MaxPoolGradGradWithArgmax` and `MaxPoolGradWithArgmaxV2`."""
|
|
397
404
|
chw_reverse_index = -3
|
|
398
405
|
|
|
399
406
|
def vmap_rule(in0_bdim, in1_bdim, in2_bdim):
|
|
@@ -552,7 +559,7 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
|
|
|
552
559
|
return prim_attr_axis
|
|
553
560
|
return prim_attr_axis + 1
|
|
554
561
|
|
|
555
|
-
@
|
|
562
|
+
@_primexpr
|
|
556
563
|
def get_batch_params_reduce_axes(begin_params_axis, x_shape):
|
|
557
564
|
if begin_params_axis < 0:
|
|
558
565
|
x_rank = len(x_shape)
|
|
@@ -560,7 +567,7 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
|
|
|
560
567
|
batch_params_reduce_axes = tuple(range(1, begin_params_axis))
|
|
561
568
|
return batch_params_reduce_axes
|
|
562
569
|
|
|
563
|
-
@
|
|
570
|
+
@_primexpr
|
|
564
571
|
def get_logical_shape(var_shape):
|
|
565
572
|
return var_shape[1:]
|
|
566
573
|
|
|
@@ -682,3 +689,28 @@ def get_upsample_grad_vmap_rule(prim, axis_size):
|
|
|
682
689
|
out = F.reshape(out, real_out_shape)
|
|
683
690
|
return out, 0
|
|
684
691
|
return vmap_rule
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
@vmap_rules_getters.register(G.LogSoftmaxGrad)
|
|
695
|
+
def get_log_softmax_vmap_rule(prim, axis_size):
|
|
696
|
+
"""VmapRule for 'LogSoftmaxGrad' operation."""
|
|
697
|
+
if isinstance(prim, str):
|
|
698
|
+
axis = -1
|
|
699
|
+
else:
|
|
700
|
+
axis = prim.axis
|
|
701
|
+
|
|
702
|
+
def vmap_rule(x_bdim, grad_bdim):
|
|
703
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
|
|
704
|
+
if is_all_none:
|
|
705
|
+
return result
|
|
706
|
+
x, x_dim = x_bdim
|
|
707
|
+
grad, _ = grad_bdim
|
|
708
|
+
x_ndim = F.rank(x) - 1
|
|
709
|
+
|
|
710
|
+
batch_axis = axis + x_ndim if axis < 0 else axis
|
|
711
|
+
batch_axis = batch_axis if batch_axis < x_dim else batch_axis + 1
|
|
712
|
+
|
|
713
|
+
dx = G.LogSoftmaxGrad(axis=batch_axis)(x, grad)
|
|
714
|
+
return dx, x_dim
|
|
715
|
+
|
|
716
|
+
return vmap_rule
|
|
@@ -16,9 +16,12 @@
|
|
|
16
16
|
"""image_ops vmap impl."""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
|
|
19
|
+
import numpy as np
|
|
20
|
+
from mindspore import Tensor
|
|
19
21
|
from mindspore.ops import functional as F
|
|
20
22
|
from mindspore.ops.operations import _grad_ops as G
|
|
21
23
|
from mindspore.ops.operations import image_ops as IMG
|
|
24
|
+
from mindspore.ops import constexpr
|
|
22
25
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
|
|
23
26
|
_raise_value_error
|
|
24
27
|
|
|
@@ -83,3 +86,52 @@ def get_resize_grad_dynamic_rule(prim, axis_size):
|
|
|
83
86
|
return out, 0
|
|
84
87
|
|
|
85
88
|
return vmap_rule
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@vmap_rules_getters.register(IMG.CropAndResize)
|
|
92
|
+
def get_crop_and_resize_vmap_rule(prim, axis_size):
|
|
93
|
+
"""VmapRule for `CropAndResize` operation."""
|
|
94
|
+
|
|
95
|
+
@constexpr
|
|
96
|
+
def get_box_indices_offsets(axis_size, batch_size, num_boxes):
|
|
97
|
+
offsets = np.arange(0, axis_size * batch_size, batch_size).astype(np.int32)
|
|
98
|
+
offsets = np.reshape(offsets, (axis_size, 1))
|
|
99
|
+
offsets = np.broadcast_to(offsets, (axis_size, num_boxes))
|
|
100
|
+
return Tensor(offsets)
|
|
101
|
+
|
|
102
|
+
def vmap_rule(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim):
|
|
103
|
+
is_all_none, result = vmap_general_preprocess(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim)
|
|
104
|
+
if is_all_none:
|
|
105
|
+
return result
|
|
106
|
+
|
|
107
|
+
x, x_dim = x_bdim
|
|
108
|
+
boxes, boxes_dim = boxes_bdim
|
|
109
|
+
box_indices, box_indices_dim = box_indices_bdim
|
|
110
|
+
crop_size, crop_size_dim = crop_size_bdim
|
|
111
|
+
if crop_size_dim is not None:
|
|
112
|
+
_raise_value_error(
|
|
113
|
+
"The axis of `crop_size` in `{}` must be None, but got {}.".format(prim.name, crop_size_dim))
|
|
114
|
+
|
|
115
|
+
boxes = _bdim_at_front(boxes, boxes_dim, axis_size)
|
|
116
|
+
box_indices = _bdim_at_front(box_indices, box_indices_dim, axis_size)
|
|
117
|
+
boxes = F.reshape(boxes, (-1, 4))
|
|
118
|
+
num_boxes = F.shape(box_indices)[-1]
|
|
119
|
+
|
|
120
|
+
if x_dim is None:
|
|
121
|
+
box_indices = F.reshape(box_indices, (-1,))
|
|
122
|
+
out = prim(x, boxes, box_indices, crop_size)
|
|
123
|
+
else:
|
|
124
|
+
x = _bdim_at_front(x, x_dim, axis_size)
|
|
125
|
+
x_shape = F.shape(x)
|
|
126
|
+
x = F.reshape(x, (-1,) + x_shape[2:])
|
|
127
|
+
offsets = get_box_indices_offsets(axis_size, x_shape[1], num_boxes)
|
|
128
|
+
box_indices = F.add(box_indices, offsets)
|
|
129
|
+
box_indices = F.reshape(box_indices, (-1,))
|
|
130
|
+
out = prim(x, boxes, box_indices, crop_size)
|
|
131
|
+
|
|
132
|
+
out_shape = F.shape(out)
|
|
133
|
+
out = F.reshape(out, (-1, num_boxes) + out_shape[1:])
|
|
134
|
+
return out, 0
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
return vmap_rule
|
|
@@ -19,13 +19,13 @@ from __future__ import absolute_import
|
|
|
19
19
|
import mindspore.numpy as mnp
|
|
20
20
|
from mindspore.ops import operations as P
|
|
21
21
|
from mindspore.ops import functional as F
|
|
22
|
-
from mindspore.ops import
|
|
22
|
+
from mindspore.ops.primitive import _primexpr
|
|
23
23
|
from mindspore.common import Tensor
|
|
24
24
|
from mindspore.ops.operations import math_ops
|
|
25
25
|
from mindspore.ops.operations import linalg_ops
|
|
26
26
|
from mindspore.ops.operations import _inner_ops
|
|
27
27
|
from mindspore.ops.primitive import Primitive
|
|
28
|
-
from mindspore.ops.
|
|
28
|
+
from mindspore.ops.function import _VmapGeneralRule
|
|
29
29
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \
|
|
30
30
|
get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \
|
|
31
31
|
_vmap_clone_prim, _bdim_at_any, _get_reduce_batch_axis, _get_reduce_out_dim
|
|
@@ -33,7 +33,7 @@ from mindspore.ops.operations.math_ops import Bernoulli, BesselI0, BesselI1, Bes
|
|
|
33
33
|
BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1, BesselK1e, Median
|
|
34
34
|
|
|
35
35
|
|
|
36
|
-
@
|
|
36
|
+
@_primexpr
|
|
37
37
|
def _broadcast_shape(nd, x_ndim, x_shape):
|
|
38
38
|
return x_shape + (1,) * (nd - x_ndim)
|
|
39
39
|
|
|
@@ -97,6 +97,35 @@ def get_broadcast_binary_op_vmap_rule(prim, axis_size):
|
|
|
97
97
|
return vmap_rule
|
|
98
98
|
|
|
99
99
|
|
|
100
|
+
@vmap_rules_getters.register(P.Addcdiv)
|
|
101
|
+
@vmap_rules_getters.register(P.Addcmul)
|
|
102
|
+
def get_addcxxx_vmap_rule(prim, axis_size):
|
|
103
|
+
"""VmapRule for addcxxx, such as `Addcdiv` and `Addcmul`."""
|
|
104
|
+
|
|
105
|
+
def vmap_rule(input_data_bdim, x1_bdim, x2_bdim, value_bdim):
|
|
106
|
+
is_all_none, result = vmap_general_preprocess(prim, input_data_bdim, x1_bdim, x2_bdim, value_bdim)
|
|
107
|
+
if is_all_none:
|
|
108
|
+
return result
|
|
109
|
+
|
|
110
|
+
input_data, input_data_dim = input_data_bdim
|
|
111
|
+
x1, x1_dim = x1_bdim
|
|
112
|
+
x2, x2_dim = x2_bdim
|
|
113
|
+
value, value_dim = value_bdim
|
|
114
|
+
if input_data_dim == x1_dim and x1_dim == x2_dim and x2_dim == value_dim:
|
|
115
|
+
out = prim(input_data, x1, x2, value)
|
|
116
|
+
return out, input_data_dim
|
|
117
|
+
|
|
118
|
+
input_data = _bdim_at_front(input_data, input_data_dim, axis_size)
|
|
119
|
+
x1 = _bdim_at_front(x1, x1_dim, axis_size)
|
|
120
|
+
x2 = _bdim_at_front(x2, x2_dim, axis_size)
|
|
121
|
+
value = _bdim_at_front(value, value_dim, axis_size)
|
|
122
|
+
|
|
123
|
+
out = prim(input_data, x1, x2, value)
|
|
124
|
+
return out, 0
|
|
125
|
+
|
|
126
|
+
return vmap_rule
|
|
127
|
+
|
|
128
|
+
|
|
100
129
|
@vmap_rules_getters.register(P.Cdist)
|
|
101
130
|
def get_cdist_vmap_rule(prim, axis_size):
|
|
102
131
|
"""VmapRule for `cdist` operation."""
|
|
@@ -358,6 +387,8 @@ def get_inplace_ops_vmap_rule(prim, axis_size):
|
|
|
358
387
|
@vmap_rules_getters.register(P.ReduceMin)
|
|
359
388
|
@vmap_rules_getters.register(P.ReduceMean)
|
|
360
389
|
@vmap_rules_getters.register(P.ReduceProd)
|
|
390
|
+
@vmap_rules_getters.register(P.ReduceAll)
|
|
391
|
+
@vmap_rules_getters.register(P.ReduceAny)
|
|
361
392
|
def get_reducer_vmap_rule(prim, axis_size):
|
|
362
393
|
"""VmapRule for reduce operations, such as `ReduceSum`."""
|
|
363
394
|
reduce_op_map = {
|
|
@@ -365,7 +396,9 @@ def get_reducer_vmap_rule(prim, axis_size):
|
|
|
365
396
|
"ReduceMax": P.ReduceMax,
|
|
366
397
|
"ReduceMin": P.ReduceMin,
|
|
367
398
|
"ReduceMean": P.ReduceMean,
|
|
368
|
-
"ReduceProd": P.ReduceProd
|
|
399
|
+
"ReduceProd": P.ReduceProd,
|
|
400
|
+
"ReduceAll": P.ReduceAll,
|
|
401
|
+
"ReduceAny": P.ReduceAny,
|
|
369
402
|
}
|
|
370
403
|
|
|
371
404
|
if isinstance(prim, str):
|
|
@@ -403,7 +436,7 @@ def get_median_vmap_rule(prim, axis_size):
|
|
|
403
436
|
axis = prim.axis
|
|
404
437
|
keep_dims = prim.keep_dims
|
|
405
438
|
|
|
406
|
-
@
|
|
439
|
+
@_primexpr
|
|
407
440
|
def trans_axis(axis, rank, dim, keep_dims):
|
|
408
441
|
if axis < 0:
|
|
409
442
|
axis += rank - 1
|
|
@@ -431,7 +464,7 @@ def get_index_add_vmap_rule(prim, axis_size):
|
|
|
431
464
|
"""VmapRule for IndexAdd."""
|
|
432
465
|
axis = prim.axis
|
|
433
466
|
|
|
434
|
-
@
|
|
467
|
+
@_primexpr
|
|
435
468
|
def _get_index_add_batch_axis(axis, x_dim, x_ndim):
|
|
436
469
|
"""get batch_axis for IndexAdd."""
|
|
437
470
|
# case1: batch not exists
|
|
@@ -770,6 +803,44 @@ def get_square_sum_all_vmap_rule(prim, axis_size):
|
|
|
770
803
|
return vmap_rule
|
|
771
804
|
|
|
772
805
|
|
|
806
|
+
@vmap_rules_getters.register(math_ops.FFTWithSize)
|
|
807
|
+
def get_fft_with_size_vmap_rule(prim, axis_size):
|
|
808
|
+
"""VmapRule for `FFTWithSize` operation"""
|
|
809
|
+
if isinstance(prim, str):
|
|
810
|
+
prim_name = prim
|
|
811
|
+
prim = Primitive(prim)
|
|
812
|
+
signal_ndim = 1
|
|
813
|
+
inverse = False
|
|
814
|
+
real = False
|
|
815
|
+
norm = "backward"
|
|
816
|
+
oneside = True
|
|
817
|
+
signal_sizes = ()
|
|
818
|
+
else:
|
|
819
|
+
prim_name = prim.name
|
|
820
|
+
signal_ndim = prim.signal_ndim
|
|
821
|
+
inverse = prim.inverse
|
|
822
|
+
real = prim.real
|
|
823
|
+
norm = prim.norm
|
|
824
|
+
oneside = prim.oneside
|
|
825
|
+
signal_sizes = prim.signal_sizes
|
|
826
|
+
|
|
827
|
+
fft = math_ops.FFTWithSize(signal_ndim, inverse, real, norm, oneside, signal_sizes)
|
|
828
|
+
|
|
829
|
+
def vmap_rule(x_bdim):
|
|
830
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
|
|
831
|
+
if is_all_none:
|
|
832
|
+
return result
|
|
833
|
+
x, x_dim = x_bdim
|
|
834
|
+
x_ndim = F.rank(x)
|
|
835
|
+
if x_dim < 0 or x_dim >= x_ndim - signal_ndim:
|
|
836
|
+
_raise_value_error("The source axi of `x` in `{} must be`in range of ({} {}), "
|
|
837
|
+
"but got {}.".format(prim_name, 0, x_ndim - signal_ndim, x_dim))
|
|
838
|
+
out = fft(x)
|
|
839
|
+
return (out, x_dim)
|
|
840
|
+
|
|
841
|
+
return vmap_rule
|
|
842
|
+
|
|
843
|
+
|
|
773
844
|
get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule)
|
|
774
845
|
get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule)
|
|
775
846
|
|
|
@@ -23,6 +23,7 @@ from mindspore.ops.operations import _grad_ops as G
|
|
|
23
23
|
from mindspore.ops.operations import nn_ops as NN
|
|
24
24
|
from mindspore.ops import functional as F
|
|
25
25
|
from mindspore.ops import constexpr
|
|
26
|
+
from mindspore.ops.primitive import _primexpr
|
|
26
27
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, \
|
|
27
28
|
_bdim_at_any, _bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, _raise_value_error, \
|
|
28
29
|
_vmap_clone_prim, _get_reduce_batch_axis
|
|
@@ -375,7 +376,7 @@ def get_bias_add_vmap_rule(prim, axis_size):
|
|
|
375
376
|
def get_channal_pos_in_x(d_format):
|
|
376
377
|
return d_format.find('C') + 1
|
|
377
378
|
|
|
378
|
-
@
|
|
379
|
+
@_primexpr
|
|
379
380
|
def get_bias_dst_shape(x_shape, n_dims, d_format, has_b_dim: bool):
|
|
380
381
|
pos = get_channal_pos_in_x(d_format)
|
|
381
382
|
|
|
@@ -430,7 +431,7 @@ def get_bias_add_grad_vmap_rule(prim, axis_size):
|
|
|
430
431
|
def get_channal_pos(d_format):
|
|
431
432
|
return d_format.find('C') + 1
|
|
432
433
|
|
|
433
|
-
@
|
|
434
|
+
@_primexpr
|
|
434
435
|
def get_axis_for_reduce(x_shape_rank, data_format):
|
|
435
436
|
channal_pos = get_channal_pos(data_format)
|
|
436
437
|
axis_list = ()
|
|
@@ -1072,24 +1073,24 @@ def get_pad_v3_vmap_rule(prim, axis_size):
|
|
|
1072
1073
|
if is_all_none:
|
|
1073
1074
|
return result
|
|
1074
1075
|
if len(params_bdim) < 2:
|
|
1075
|
-
_raise_value_error("The input params in `
|
|
1076
|
-
"but got {}.".format(
|
|
1076
|
+
_raise_value_error("The input params in `PadV3` must >= 2, "
|
|
1077
|
+
"but got {}.".format(len(params_bdim)))
|
|
1077
1078
|
input_x, input_x_dim = params_bdim[0]
|
|
1078
1079
|
paddings, paddings_dim = params_bdim[1]
|
|
1079
1080
|
values = None
|
|
1080
1081
|
out = None
|
|
1081
1082
|
x = _bdim_at_front(input_x, input_x_dim, axis_size)
|
|
1082
1083
|
if paddings_dim is not None:
|
|
1083
|
-
_raise_value_error("The source axis of `paddings` in `
|
|
1084
|
-
"but got {}.".format(
|
|
1084
|
+
_raise_value_error("The source axis of `paddings` in `PadV3` must be None, "
|
|
1085
|
+
"but got {}.".format(paddings_dim))
|
|
1085
1086
|
if mode == "constant":
|
|
1086
1087
|
if len(params_bdim) != 3:
|
|
1087
|
-
_raise_value_error("The input params in `
|
|
1088
|
-
"but got {}.".format(
|
|
1088
|
+
_raise_value_error("The input params in `PadV3` of constant mode must be 3, "
|
|
1089
|
+
"but got {}.".format(len(params_bdim)))
|
|
1089
1090
|
values, values_dim = params_bdim[2]
|
|
1090
1091
|
if values_dim is not None:
|
|
1091
|
-
_raise_value_error("The source axis of `values_dim` in `
|
|
1092
|
-
"but got {}.".format(
|
|
1092
|
+
_raise_value_error("The source axis of `values_dim` in `PadV3` must be None, "
|
|
1093
|
+
"but got {}.".format(values_dim))
|
|
1093
1094
|
if isinstance(paddings, Tensor):
|
|
1094
1095
|
pad_dim = F.shape(paddings)[0] / pad_pair
|
|
1095
1096
|
else:
|
|
@@ -1101,7 +1102,7 @@ def get_pad_v3_vmap_rule(prim, axis_size):
|
|
|
1101
1102
|
out = prim(x, paddings, values)
|
|
1102
1103
|
else:
|
|
1103
1104
|
out = prim(x, paddings)
|
|
1104
|
-
elif x_ndim
|
|
1105
|
+
elif x_ndim >= input_max_dim:
|
|
1105
1106
|
# reshape to 4 dims
|
|
1106
1107
|
x_shape = F.shape(x)
|
|
1107
1108
|
diff_dim = x_ndim - input_max_dim
|
|
@@ -1118,8 +1119,8 @@ def get_pad_v3_vmap_rule(prim, axis_size):
|
|
|
1118
1119
|
real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
|
|
1119
1120
|
out = F.reshape(out, real_out_shape)
|
|
1120
1121
|
else:
|
|
1121
|
-
_raise_value_error("The dim of `input_x` in `
|
|
1122
|
-
"but got {}.".format(
|
|
1122
|
+
_raise_value_error("The dim of `input_x` in `PadV3` must be bigger than {}, "
|
|
1123
|
+
"but got {}.".format(pad_dim, x_ndim))
|
|
1123
1124
|
return out, 0
|
|
1124
1125
|
|
|
1125
1126
|
return vmap_rule
|
|
@@ -1308,6 +1309,60 @@ def get_apply_adam_with_amsgrad_rule(prim, axis_size):
|
|
|
1308
1309
|
return vmap_rule
|
|
1309
1310
|
|
|
1310
1311
|
|
|
1312
|
+
@vmap_rules_getters.register(P.Adam)
|
|
1313
|
+
def get_adam_rule(prim, axis_size):
|
|
1314
|
+
"""VmapRule for `Adam` operation"""
|
|
1315
|
+
if hasattr(prim, "batch_rank"):
|
|
1316
|
+
batch_rank = prim.batch_rank + 1
|
|
1317
|
+
else:
|
|
1318
|
+
batch_rank = 1
|
|
1319
|
+
prim_name = prim.name
|
|
1320
|
+
batch_prim = _vmap_clone_prim(prim)
|
|
1321
|
+
batch_prim.add_prim_attr("batch_rank", batch_rank)
|
|
1322
|
+
|
|
1323
|
+
def vmap_rule(var_bdim, m_bdim, v_bdim, beta1_power_bdim, beta2_power_bdim, lr_bdim, beta1_bdim,
|
|
1324
|
+
beta2_bdim, epsilon_bdim, grad_bdim, u_monad):
|
|
1325
|
+
var, var_dim = var_bdim
|
|
1326
|
+
m, m_dim = m_bdim
|
|
1327
|
+
v, v_dim = v_bdim
|
|
1328
|
+
beta1_power, beta1_power_dim = beta1_power_bdim
|
|
1329
|
+
beta2_power, beta2_power_dim = beta2_power_bdim
|
|
1330
|
+
lr, lr_dim = lr_bdim
|
|
1331
|
+
beta1, beta1_dim = beta1_bdim
|
|
1332
|
+
beta2, beta2_dim = beta2_bdim
|
|
1333
|
+
epsilon, epsilon_dim = epsilon_bdim
|
|
1334
|
+
grad, grad_dim = grad_bdim
|
|
1335
|
+
|
|
1336
|
+
all_dim = [m_dim, v_dim, beta1_power_dim, beta2_power_dim, lr_dim, beta1_dim, beta2_dim, epsilon_dim, grad_dim]
|
|
1337
|
+
if var_dim is None:
|
|
1338
|
+
if any(dim is not None for dim in all_dim):
|
|
1339
|
+
raise ValueError("The source axis of `var` is None, "
|
|
1340
|
+
"but the source axis of `m/v/vhat/beta1_power/beta2_power/lr/beta1/beta2/epsilon grad"
|
|
1341
|
+
" is not None. The execution of operator `{}` cannot be guaranteed.".format(prim_name))
|
|
1342
|
+
out_var, out_m, out_v = prim(
|
|
1343
|
+
var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, u_monad)
|
|
1344
|
+
return ((out_var, None), (out_m, None), (out_v, None))
|
|
1345
|
+
|
|
1346
|
+
if any(dim != 0 for dim in [var_dim, m_dim, v_dim]):
|
|
1347
|
+
raise ValueError("For `{}`, the source axis of `var/m/v` must be 0, "
|
|
1348
|
+
"but get `var`: {}, `m`: {}, `v`: {}".format(prim_name, var_dim,
|
|
1349
|
+
m_dim, v_dim))
|
|
1350
|
+
|
|
1351
|
+
beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
|
|
1352
|
+
beta2_power = _bdim_at_front(beta2_power, beta2_power_dim, axis_size)
|
|
1353
|
+
lr = _bdim_at_front(lr, lr_dim, axis_size)
|
|
1354
|
+
beta1 = _bdim_at_front(beta1, beta1_dim, axis_size)
|
|
1355
|
+
beta2 = _bdim_at_front(beta2, beta2_dim, axis_size)
|
|
1356
|
+
epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
|
|
1357
|
+
grad = _bdim_at_front(grad, grad_dim, axis_size)
|
|
1358
|
+
|
|
1359
|
+
out_var, out_m, out_v = batch_prim(
|
|
1360
|
+
var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, u_monad)
|
|
1361
|
+
return ((out_var, 0), (out_m, 0), (out_v, 0))
|
|
1362
|
+
|
|
1363
|
+
return vmap_rule
|
|
1364
|
+
|
|
1365
|
+
|
|
1311
1366
|
@vmap_rules_getters.register(P.ApplyPowerSign)
|
|
1312
1367
|
def get_apply_power_sign_rule(prim, axis_size):
|
|
1313
1368
|
"""VmapRule for `ApplyPowerSign` operation."""
|
|
@@ -1461,10 +1516,9 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
|
|
|
1461
1516
|
nchw_index = 4
|
|
1462
1517
|
chw_reverse_index = -3
|
|
1463
1518
|
hw_size = 2
|
|
1464
|
-
return_indices = prim.return_indices
|
|
1465
1519
|
output_size = prim.output_size
|
|
1466
1520
|
|
|
1467
|
-
@
|
|
1521
|
+
@_primexpr
|
|
1468
1522
|
def get_output_shape(x_ori_shape, output_size):
|
|
1469
1523
|
if isinstance(output_size, tuple):
|
|
1470
1524
|
h_out, w_out = output_size
|
|
@@ -1499,20 +1553,14 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
|
|
|
1499
1553
|
x_ori_shape = F.shape(x)
|
|
1500
1554
|
x = F.reshape(x, (-1,) + x_ori_shape[chw_reverse_index:])
|
|
1501
1555
|
output_shape = get_output_shape(x_ori_shape, output_size)
|
|
1502
|
-
if return_indices:
|
|
1503
|
-
out, indices = prim(x)
|
|
1504
|
-
out = F.reshape(out, output_shape)
|
|
1505
|
-
indices = F.reshape(indices, output_shape)
|
|
1506
|
-
return (out, 0), (indices, 0)
|
|
1507
|
-
out = prim(x)
|
|
1508
|
-
out = F.reshape(out, output_shape)
|
|
1509
|
-
return out, 0
|
|
1510
|
-
# for the case of CHW
|
|
1511
|
-
if return_indices:
|
|
1512
1556
|
out, indices = prim(x)
|
|
1557
|
+
out = F.reshape(out, output_shape)
|
|
1558
|
+
indices = F.reshape(indices, output_shape)
|
|
1513
1559
|
return (out, 0), (indices, 0)
|
|
1514
|
-
|
|
1515
|
-
|
|
1560
|
+
|
|
1561
|
+
# for the case of CHW
|
|
1562
|
+
out, indices = prim(x)
|
|
1563
|
+
return (out, 0), (indices, 0)
|
|
1516
1564
|
|
|
1517
1565
|
return vmap_rule
|
|
1518
1566
|
|
|
@@ -1657,6 +1705,7 @@ def get_apply_centered_rmsprop_vmap_rule(prim, axis_size):
|
|
|
1657
1705
|
|
|
1658
1706
|
@vmap_rules_getters.register(P.MaxPool)
|
|
1659
1707
|
@vmap_rules_getters.register(P.MaxPoolWithArgmax)
|
|
1708
|
+
@vmap_rules_getters.register(P.MaxPoolWithArgmaxV2)
|
|
1660
1709
|
def get_max_pool_vmap_rule(prim, axis_size):
|
|
1661
1710
|
"""VmapRule for `MaxPool` operation."""
|
|
1662
1711
|
if isinstance(prim, str):
|
|
@@ -1664,7 +1713,7 @@ def get_max_pool_vmap_rule(prim, axis_size):
|
|
|
1664
1713
|
|
|
1665
1714
|
prim_name = prim.name
|
|
1666
1715
|
|
|
1667
|
-
@
|
|
1716
|
+
@_primexpr
|
|
1668
1717
|
def get_original_shape(x_shape, out_shape):
|
|
1669
1718
|
h_new = out_shape[2]
|
|
1670
1719
|
w_new = out_shape[3]
|
|
@@ -1709,7 +1758,7 @@ def get_layernorm_vmap_rule(prim, axis_size):
|
|
|
1709
1758
|
params_axis = process_attr_axis(prim.begin_params_axis)
|
|
1710
1759
|
batch_prim = P.LayerNorm(norm_axis, params_axis, prim.epsilon)
|
|
1711
1760
|
|
|
1712
|
-
@
|
|
1761
|
+
@_primexpr
|
|
1713
1762
|
def get_logical_shape(var_shape):
|
|
1714
1763
|
return var_shape[1:]
|
|
1715
1764
|
|
|
@@ -83,7 +83,9 @@ def get_partical_vmap_rule(prim, axis_size):
|
|
|
83
83
|
else:
|
|
84
84
|
val, dim = val_bdim
|
|
85
85
|
if dim is not None:
|
|
86
|
-
_raise_value_error("
|
|
86
|
+
_raise_value_error("In the scenario where vmap contains control flow, currently only the "
|
|
87
|
+
"case of each batch branch with the same processing operations is "
|
|
88
|
+
"supported, so that the source axis of args in {} must be None, "
|
|
87
89
|
"but got {}.".format(prim_name, dim))
|
|
88
90
|
vals = vals + (val,)
|
|
89
91
|
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
|
|
2
1
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
3
2
|
#
|
|
4
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -17,8 +16,11 @@
|
|
|
17
16
|
"""random_ops vmap impl."""
|
|
18
17
|
from __future__ import absolute_import
|
|
19
18
|
|
|
20
|
-
from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
|
|
21
|
-
|
|
19
|
+
from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle, Multinomial, \
|
|
20
|
+
RandomChoiceWithMask
|
|
21
|
+
from mindspore.ops.function import _VmapGeneralRule
|
|
22
|
+
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, _bdim_at_front, _vmap_clone_prim, \
|
|
23
|
+
vmap_general_preprocess, _raise_value_error
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
@vmap_rules_getters.register(UniformCandidateSampler)
|
|
@@ -68,3 +70,53 @@ def get_random_shuffle_vmap_rule(prim, axis_size):
|
|
|
68
70
|
return out, 0
|
|
69
71
|
|
|
70
72
|
return vmap_rule
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@vmap_rules_getters.register(Multinomial)
|
|
76
|
+
def get_multinomial_vmap_rule(prim, axis_size):
|
|
77
|
+
"""VmapRule for `Multinomial` operation."""
|
|
78
|
+
prim_name = prim.name
|
|
79
|
+
prim_vmap = _VmapGeneralRule(prim, axis_size)
|
|
80
|
+
|
|
81
|
+
def vmap_rule(x_bdim, num_samples_bdim):
|
|
82
|
+
is_all_none, result = vmap_general_preprocess(
|
|
83
|
+
prim, x_bdim, num_samples_bdim)
|
|
84
|
+
if is_all_none:
|
|
85
|
+
return result
|
|
86
|
+
|
|
87
|
+
x, x_dim = x_bdim
|
|
88
|
+
num_samples, num_samples_dim = num_samples_bdim
|
|
89
|
+
if len(x.shape) > 2:
|
|
90
|
+
out = prim_vmap(x_bdim, num_samples_bdim)
|
|
91
|
+
return out
|
|
92
|
+
if num_samples_dim is not None:
|
|
93
|
+
_raise_value_error("The source axis of args in {} must be None, "
|
|
94
|
+
"but got {}.".format(prim_name, num_samples_dim))
|
|
95
|
+
x = _bdim_at_front(x, x_dim, axis_size)
|
|
96
|
+
out = prim(x, num_samples)
|
|
97
|
+
return (out, 0)
|
|
98
|
+
|
|
99
|
+
return vmap_rule
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@vmap_rules_getters.register(RandomChoiceWithMask)
|
|
103
|
+
def get_random_choice_with_mask(prim, axis_size):
|
|
104
|
+
"""VmapRule for 'RandomChoiceWithMask' operation."""
|
|
105
|
+
if hasattr(prim, 'batch_rank'):
|
|
106
|
+
batch_rank = prim.batch_rank + 1
|
|
107
|
+
else:
|
|
108
|
+
batch_rank = 1
|
|
109
|
+
|
|
110
|
+
batch_prim = _vmap_clone_prim(prim)
|
|
111
|
+
batch_prim.add_prim_attr('batch_rank', batch_rank)
|
|
112
|
+
|
|
113
|
+
def vmap_rule(x_bdim):
|
|
114
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
|
|
115
|
+
if is_all_none:
|
|
116
|
+
return result
|
|
117
|
+
x_data, x_dim = x_bdim
|
|
118
|
+
x = _bdim_at_front(x_data, x_dim, axis_size)
|
|
119
|
+
index, mask = batch_prim(x)
|
|
120
|
+
return (index, 0), (mask, 0)
|
|
121
|
+
|
|
122
|
+
return vmap_rule
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
|
|
16
16
|
"""sparse_ops vmap impl."""
|
|
17
|
+
from __future__ import absolute_import
|
|
17
18
|
|
|
18
19
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error
|
|
19
20
|
from mindspore.ops.primitive import Primitive
|
|
Binary file
|
|
Binary file
|
|
@@ -1,20 +1,19 @@
|
|
|
1
1
|
|
|
2
|
-
0.1.1 MindSpore*
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
bprop.
|
|
14
|
-
|
|
15
|
-
bprop.
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
S-Prim-MakeTuple:5S-Prim-
|
|
20
|
-
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
2
|
+
0.1.1 MindSpore*2.0.0:�
|
|
3
|
+
�.get_bprop_approximate_equal.1184:[CNode]1185:1.get_bprop_approximate_equal.1184:[CNode]1185:1"REF::bprop.1186:Default/bprop.1186-op927 get_bprop_approximate_equal.1184*'
|
|
4
|
+
%get_bprop_approximate_equal.1184:self*$
|
|
5
|
+
"get_bprop_approximate_equal.1184:x*$
|
|
6
|
+
"get_bprop_approximate_equal.1184:y*&
|
|
7
|
+
$get_bprop_approximate_equal.1184:out*'
|
|
8
|
+
%get_bprop_approximate_equal.1184:dout20
|
|
9
|
+
.get_bprop_approximate_equal.1184:[CNode]1185:1:@7fb54a66e55c2c40cd92783044880b792666a0d7fc794bb717bee3544337d6a0J/grad_math_ops.pyB�
|
|
10
|
+
�
|
|
11
|
+
"get_bprop_approximate_equal.1184:xbprop.1186:[CNode]1187:2bprop.1186:[CNode]1187:2".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op928
|
|
12
|
+
�
|
|
13
|
+
"get_bprop_approximate_equal.1184:ybprop.1186:[CNode]1188:3bprop.1186:[CNode]1188:3".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op929
|
|
14
|
+
�
|
|
15
|
+
bprop.1186:[CNode]1187:2
|
|
16
|
+
bprop.1186:[CNode]1188:3bprop.1186:[CNode]1189:4bprop.1186:[CNode]1189:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op930
|
|
17
|
+
bprop.11862
|
|
18
|
+
bprop.1186:[CNode]1189:4Pb&
|
|
19
|
+
S-Prim-MakeTuple:5S-Prim-MakeTupleh
|