mindspore 1.10.0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -16,25 +16,26 @@
|
|
|
16
16
|
"""array_ops vmap impl."""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
|
|
19
|
-
import numpy as np
|
|
20
19
|
import mindspore
|
|
21
20
|
import mindspore.numpy as mnp
|
|
22
21
|
from mindspore import ops
|
|
23
22
|
from mindspore.common import Tensor
|
|
23
|
+
from mindspore._c_expression import Tensor as Tensor_
|
|
24
24
|
from mindspore.ops import operations as P
|
|
25
25
|
from mindspore.ops import functional as F
|
|
26
|
-
from mindspore.ops import constexpr
|
|
26
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
27
27
|
from mindspore.ops.operations._grad_ops import MaskedSelectGrad
|
|
28
28
|
from mindspore.ops.operations import _grad_ops as G
|
|
29
29
|
from mindspore.ops.operations.array_ops import Fills, UniqueConsecutive, Col2Im, NonZero, IndexFill, \
|
|
30
30
|
TensorScatterElements
|
|
31
31
|
from mindspore.ops.operations.random_ops import RandomPoisson
|
|
32
|
+
from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
|
|
32
33
|
from mindspore.ops.primitive import Primitive
|
|
33
34
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
|
|
34
35
|
_raise_value_error, _vmap_clone_prim, _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis, \
|
|
35
|
-
get_unop_vmap_rule, _get_reduce_out_dim, _get_reduce_batch_axis,
|
|
36
|
+
get_unop_vmap_rule, _get_reduce_out_dim, _get_reduce_batch_axis, \
|
|
36
37
|
_bdim_at_any
|
|
37
|
-
from mindspore.ops.
|
|
38
|
+
from mindspore.ops.function import _VmapGeneralRule
|
|
38
39
|
|
|
39
40
|
|
|
40
41
|
@vmap_rules_getters.register(P.NoRepeatNGram)
|
|
@@ -137,7 +138,7 @@ def get_arg_min_max_with_value_vmap_rule(prim, axis_size):
|
|
|
137
138
|
return vmap_rule
|
|
138
139
|
|
|
139
140
|
|
|
140
|
-
@
|
|
141
|
+
@_primexpr
|
|
141
142
|
def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
142
143
|
"""
|
|
143
144
|
Generate prefix by indices shape, whose -1 axis value is the index value of axis 0.
|
|
@@ -147,14 +148,16 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
147
148
|
the generated prefix is a Tensor([[[0], [0]],
|
|
148
149
|
[[1], [1]]])
|
|
149
150
|
"""
|
|
150
|
-
|
|
151
|
-
|
|
151
|
+
def _check(indices_shape):
|
|
152
|
+
if not indices_shape:
|
|
153
|
+
raise ValueError("indices_shape is empty in _get_prefix.")
|
|
152
154
|
|
|
155
|
+
_check(indices_shape)
|
|
153
156
|
indices_len = len(indices_shape)
|
|
154
|
-
|
|
155
157
|
if indices_len == 1:
|
|
156
|
-
prefix =
|
|
157
|
-
|
|
158
|
+
prefix = P.Range()(Tensor(0, indices_dtype), P.Fill()(
|
|
159
|
+
indices_dtype, (), axis_size), Tensor(1, indices_dtype))
|
|
160
|
+
return prefix
|
|
158
161
|
|
|
159
162
|
indices_end = indices_len - 1
|
|
160
163
|
prefix_shape = ()
|
|
@@ -169,8 +172,9 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
169
172
|
else:
|
|
170
173
|
expand_shape = expand_shape + (1,)
|
|
171
174
|
|
|
172
|
-
prefix =
|
|
173
|
-
|
|
175
|
+
prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(Tensor(
|
|
176
|
+
0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype)), expand_shape))
|
|
177
|
+
return prefix
|
|
174
178
|
|
|
175
179
|
|
|
176
180
|
@vmap_rules_getters.register(P.Transpose)
|
|
@@ -179,7 +183,7 @@ def get_transpose_vmap_rule(prim, axis_size):
|
|
|
179
183
|
if isinstance(prim, str):
|
|
180
184
|
prim = Primitive(prim)
|
|
181
185
|
|
|
182
|
-
@
|
|
186
|
+
@_primexpr
|
|
183
187
|
def _get_transpose_batch_perm(dim, perm, x_rank):
|
|
184
188
|
"""Generate batch_perm based on the original perm of transpose operation and dim of the input."""
|
|
185
189
|
if dim < 0:
|
|
@@ -223,24 +227,20 @@ def get_tile_vmap_rule(prim, axis_size):
|
|
|
223
227
|
if isinstance(prim, str):
|
|
224
228
|
prim = Primitive(prim)
|
|
225
229
|
|
|
226
|
-
@
|
|
227
|
-
def
|
|
230
|
+
@_primexpr
|
|
231
|
+
def _get_batch_multiples(input_shape, dim, multiples):
|
|
228
232
|
input_ndim = len(input_shape)
|
|
229
233
|
multiples_ndim = len(multiples)
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
for k in pair
|
|
241
|
-
])
|
|
242
|
-
output_shape = tuple([a * b for a, b in zip(input_shape, multiples)])
|
|
243
|
-
return input_expand_shape, repeat_shape, output_shape
|
|
234
|
+
if multiples_ndim < input_ndim - 1:
|
|
235
|
+
multiples = (1,) * (input_ndim - 1 - multiples_ndim) + multiples
|
|
236
|
+
|
|
237
|
+
rev_dim = input_ndim - 1 - dim
|
|
238
|
+
if rev_dim == 0:
|
|
239
|
+
return multiples + (1,), multiples_ndim
|
|
240
|
+
|
|
241
|
+
batch_multiples = list(multiples)
|
|
242
|
+
batch_multiples.insert(-rev_dim, 1)
|
|
243
|
+
return tuple(batch_multiples), multiples_ndim - rev_dim
|
|
244
244
|
|
|
245
245
|
def vmap_rule(input_bdim, multiples_bdim):
|
|
246
246
|
is_all_none, result = vmap_general_preprocess(prim, input_bdim, multiples_bdim)
|
|
@@ -252,13 +252,10 @@ def get_tile_vmap_rule(prim, axis_size):
|
|
|
252
252
|
if multiples_dim is not None:
|
|
253
253
|
_raise_value_error("The source axis of shape in `Tile` must be None, but got {}.".format(multiples_dim))
|
|
254
254
|
|
|
255
|
-
input_x = _bdim_at_front(input_x, dim, axis_size)
|
|
256
255
|
input_shape = F.shape(input_x)
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
repeat_tensor
|
|
260
|
-
output = F.reshape(repeat_tensor, output_shape)
|
|
261
|
-
return output, 0
|
|
256
|
+
batch_multiples, out_dim = _get_batch_multiples(input_shape, dim, multiples)
|
|
257
|
+
repeat_tensor = P.Tile()(input_x, batch_multiples)
|
|
258
|
+
return repeat_tensor, out_dim
|
|
262
259
|
|
|
263
260
|
return vmap_rule
|
|
264
261
|
|
|
@@ -359,8 +356,13 @@ def get_unstack_vmap_rule(prim, axis_size):
|
|
|
359
356
|
def get_reshape_vmap_rule(prim, axis_size):
|
|
360
357
|
"""VmapRule for `Reshape` operation."""
|
|
361
358
|
|
|
362
|
-
|
|
359
|
+
|
|
360
|
+
@_primexpr
|
|
363
361
|
def get_batch_shape(x_shape, x_dim, target_shape, axis_size):
|
|
362
|
+
def _check(neg_index, target_shape):
|
|
363
|
+
if neg_index != -1:
|
|
364
|
+
raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
|
|
365
|
+
|
|
364
366
|
if x_dim == 0:
|
|
365
367
|
return (axis_size,) + target_shape, 0, False
|
|
366
368
|
|
|
@@ -371,19 +373,21 @@ def get_reshape_vmap_rule(prim, axis_size):
|
|
|
371
373
|
dim_prod = 1
|
|
372
374
|
for i, shp_i in enumerate(target_shape):
|
|
373
375
|
if shp_i == -1:
|
|
374
|
-
|
|
375
|
-
raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
|
|
376
|
+
_check(neg_index, target_shape)
|
|
376
377
|
neg_index = i
|
|
377
378
|
else:
|
|
378
379
|
dim_prod *= shp_i
|
|
379
|
-
arr_prod =
|
|
380
|
+
arr_prod = 1
|
|
381
|
+
for i in x_shape:
|
|
382
|
+
arr_prod *= i
|
|
380
383
|
target_shape_list = list(target_shape)
|
|
381
384
|
if neg_index != -1:
|
|
382
385
|
neg_index_size = int(arr_prod // (dim_prod * axis_size))
|
|
383
386
|
target_shape_list[neg_index] = neg_index_size
|
|
384
387
|
|
|
385
|
-
arr_prod_before_dim =
|
|
386
|
-
|
|
388
|
+
arr_prod_before_dim = 1
|
|
389
|
+
for i in x_shape[:x_dim]:
|
|
390
|
+
arr_prod_before_dim *= i
|
|
387
391
|
dim_prod = 1
|
|
388
392
|
for i, shp_i in enumerate(target_shape_list, start=1):
|
|
389
393
|
dim_prod *= shp_i
|
|
@@ -428,7 +432,7 @@ def get_reverse_sequence_vmap_rule(prim, axis_size):
|
|
|
428
432
|
batch_dim = prim.batch_dim_
|
|
429
433
|
seq_dim = prim.seq_dim_
|
|
430
434
|
|
|
431
|
-
@
|
|
435
|
+
@_primexpr
|
|
432
436
|
def get_batch_seq_dim(dim, batch_dim_, seq_dim_):
|
|
433
437
|
if dim is None:
|
|
434
438
|
batch_dim_ += 1
|
|
@@ -444,7 +448,7 @@ def get_reverse_sequence_vmap_rule(prim, axis_size):
|
|
|
444
448
|
seq_dim_ += 1
|
|
445
449
|
return batch_dim_, seq_dim_
|
|
446
450
|
|
|
447
|
-
@
|
|
451
|
+
@_primexpr
|
|
448
452
|
def get_seq_dim(dim, batch_dim_, seq_dim_):
|
|
449
453
|
if dim is None:
|
|
450
454
|
return seq_dim_
|
|
@@ -564,20 +568,19 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
|
|
|
564
568
|
Reshape the output tensor to `[10, 6, 4, 5]`
|
|
565
569
|
"""
|
|
566
570
|
|
|
567
|
-
@
|
|
571
|
+
@_primexpr
|
|
568
572
|
def _refine_shape(shape, bdim_size):
|
|
569
573
|
offset = shape[0]
|
|
570
574
|
return (bdim_size * shape[0],) + tuple(shape[1:]), offset, (bdim_size,) + tuple(shape)
|
|
571
575
|
|
|
572
|
-
@
|
|
576
|
+
@_primexpr
|
|
573
577
|
def _gen_indices_offset(shape, offset):
|
|
574
578
|
# original rank(indices.shape) is required >= 2, so indices with batch dim's rank >= 3.
|
|
575
|
-
shape =
|
|
576
|
-
val =
|
|
577
|
-
val = np.reshape(val, (shape[0], shape[-1]))
|
|
579
|
+
shape = (shape[0],) + (1,) * (len(shape) - 2) + (shape[-1],)
|
|
580
|
+
val = P.Zeros()((shape[0], shape[-1]), mindspore.int32)
|
|
578
581
|
for i in range(shape[0]):
|
|
579
582
|
val[i, 0] = i * offset
|
|
580
|
-
return
|
|
583
|
+
return P.Reshape()(val, shape)
|
|
581
584
|
|
|
582
585
|
if isinstance(prim, str):
|
|
583
586
|
prim = Primitive(prim)
|
|
@@ -598,7 +601,7 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
|
|
|
598
601
|
indices_shape = F.shape(indices)
|
|
599
602
|
indices_dtype = F.dtype(indices)
|
|
600
603
|
offset_val = _gen_indices_offset(indices_shape, offset)
|
|
601
|
-
indices_offset =
|
|
604
|
+
indices_offset = P.Cast()(offset_val, indices_dtype)
|
|
602
605
|
new_indices = P.Add()(indices, indices_offset)
|
|
603
606
|
out = prim(new_indices, updates, new_shape)
|
|
604
607
|
real_out = P.Reshape()(out, out_shape)
|
|
@@ -846,6 +849,62 @@ def get_fill_vmap_rule(prim, axis_size):
|
|
|
846
849
|
return vmap_rule
|
|
847
850
|
|
|
848
851
|
|
|
852
|
+
@constexpr
|
|
853
|
+
def to_tensor_with_type(x, type):
|
|
854
|
+
"""x to Tensor with type"""
|
|
855
|
+
return Tensor(x, type)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
@vmap_rules_getters.register(P.FillV2)
|
|
859
|
+
def get_fill_v2_vmap_rule(prim, axis_size):
|
|
860
|
+
"""VmapRule for `FillV2` operation."""
|
|
861
|
+
if isinstance(prim, str):
|
|
862
|
+
prim = Primitive(prim)
|
|
863
|
+
|
|
864
|
+
def vmap_rule(shape_bdim, value_bdim):
|
|
865
|
+
is_all_none, result = vmap_general_preprocess(prim, shape_bdim, value_bdim)
|
|
866
|
+
if is_all_none:
|
|
867
|
+
return result
|
|
868
|
+
|
|
869
|
+
value_shape, shape_dim = shape_bdim
|
|
870
|
+
if shape_dim is not None:
|
|
871
|
+
_raise_value_error(
|
|
872
|
+
"The source axis of `shape` in `P.FillV2` must be None, but got {}."
|
|
873
|
+
.format(shape_dim))
|
|
874
|
+
|
|
875
|
+
value, vdim = value_bdim
|
|
876
|
+
value_rank = F.rank(value)
|
|
877
|
+
if value_rank != 1 or vdim != 0:
|
|
878
|
+
_raise_value_error(
|
|
879
|
+
"The `value` in `P.FillV2` must be constant value, thus the value only "
|
|
880
|
+
"can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
|
|
881
|
+
"{} with source axis: {}.".format(value_rank, vdim))
|
|
882
|
+
value = F.reshape(value, (axis_size,) + (1,) * len(value_shape))
|
|
883
|
+
|
|
884
|
+
out = None
|
|
885
|
+
if isinstance(value_shape, (Tensor_, Tensor)):
|
|
886
|
+
value_shape_rank = F.rank(value_shape)
|
|
887
|
+
if value_shape_rank != 1:
|
|
888
|
+
_raise_value_error(
|
|
889
|
+
"The `shape` in `P.FillV2` must be 1-D tensor, thus the shape only "
|
|
890
|
+
"can be rank: 1, but got shape rank: "
|
|
891
|
+
"{}.".format(value_shape_rank))
|
|
892
|
+
axis_size_tensor = to_tensor_with_type((axis_size,),
|
|
893
|
+
F.dtype(value_shape))
|
|
894
|
+
broad_cast_shape = F.concat((axis_size_tensor, value_shape))
|
|
895
|
+
out = DynamicBroadcastTo()(value, broad_cast_shape)
|
|
896
|
+
elif isinstance(value_shape, tuple):
|
|
897
|
+
out = P.BroadcastTo((axis_size,) + value_shape)(value)
|
|
898
|
+
else:
|
|
899
|
+
_raise_value_error(
|
|
900
|
+
f"For `P.FillV2`, the input `shape` should be Tuple or Tensor, but got `shape`: {value_shape}."
|
|
901
|
+
)
|
|
902
|
+
|
|
903
|
+
return out, 0
|
|
904
|
+
|
|
905
|
+
return vmap_rule
|
|
906
|
+
|
|
907
|
+
|
|
849
908
|
@vmap_rules_getters.register(Fills)
|
|
850
909
|
def get_fills_vmap_rule(prim, axis_size):
|
|
851
910
|
"""VmapRule for `Fills` operation."""
|
|
@@ -1299,12 +1358,7 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
|
|
|
1299
1358
|
if isinstance(prim, str):
|
|
1300
1359
|
prim = Primitive(prim)
|
|
1301
1360
|
|
|
1302
|
-
dim
|
|
1303
|
-
if hasattr(prim, 'dim'):
|
|
1304
|
-
dim = prim.dim
|
|
1305
|
-
|
|
1306
|
-
@constexpr
|
|
1307
|
-
def _update_attr(x_rank, batch_dim):
|
|
1361
|
+
def _update_dim(dim, x_rank, batch_dim):
|
|
1308
1362
|
pdim = dim
|
|
1309
1363
|
if pdim < 0:
|
|
1310
1364
|
pdim += x_rank
|
|
@@ -1312,19 +1366,22 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
|
|
|
1312
1366
|
_raise_value_error(
|
|
1313
1367
|
"The `dim` in `GatherDGradV2` must be in range [{}, {}], but got {}.".format(-x_rank, x_rank - 1, dim))
|
|
1314
1368
|
if pdim >= batch_dim:
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1369
|
+
return pdim + 1
|
|
1370
|
+
if dim < 0:
|
|
1371
|
+
return pdim
|
|
1372
|
+
return dim
|
|
1318
1373
|
|
|
1319
|
-
def vmap_rule(x_bdim, index_bdim, grad_bdim):
|
|
1320
|
-
is_all_none, result = vmap_general_preprocess(prim, x_bdim, index_bdim, grad_bdim)
|
|
1374
|
+
def vmap_rule(x_bdim, dim_bdim, index_bdim, grad_bdim):
|
|
1375
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim, grad_bdim)
|
|
1321
1376
|
if is_all_none:
|
|
1322
1377
|
return result
|
|
1323
1378
|
|
|
1324
1379
|
x, x_dim = x_bdim
|
|
1380
|
+
dim, dim_dim = dim_bdim
|
|
1381
|
+
if dim_dim is not None:
|
|
1382
|
+
_raise_value_error("The dim of 'dim' in `GatherDGradV2` must be None, but got {}.".format(dim_dim))
|
|
1325
1383
|
index, index_dim = index_bdim
|
|
1326
1384
|
grad, grad_dim = grad_bdim
|
|
1327
|
-
|
|
1328
1385
|
batch_dim = 0
|
|
1329
1386
|
if x_dim is not None:
|
|
1330
1387
|
batch_dim = x_dim
|
|
@@ -1336,12 +1393,10 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
|
|
|
1336
1393
|
x = _bdim_at_any(x, x_dim, batch_dim, axis_size)
|
|
1337
1394
|
index = _bdim_at_any(index, index_dim, batch_dim, axis_size)
|
|
1338
1395
|
grad = _bdim_at_any(grad, grad_dim, batch_dim, axis_size)
|
|
1339
|
-
|
|
1340
|
-
# Adjust dim-attr if needed
|
|
1341
1396
|
x_rank = F.rank(x) - 1
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
out = prim(x, index, grad)
|
|
1397
|
+
# Adjust dim if needed
|
|
1398
|
+
dim = _update_dim(dim, x_rank, batch_dim)
|
|
1399
|
+
out = prim(x, dim, index, grad)
|
|
1345
1400
|
return (out, batch_dim)
|
|
1346
1401
|
|
|
1347
1402
|
return vmap_rule
|
|
@@ -1425,6 +1480,7 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1425
1480
|
"The input number of P.Meshgrid must be greater than 1.")
|
|
1426
1481
|
|
|
1427
1482
|
output_shape = []
|
|
1483
|
+
ones_shape = []
|
|
1428
1484
|
for each_arg in args:
|
|
1429
1485
|
x, bdim = each_arg
|
|
1430
1486
|
if bdim is None:
|
|
@@ -1435,19 +1491,30 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1435
1491
|
_raise_value_error(
|
|
1436
1492
|
"Each input of Meshgrid must be 1D, but got {}.".format(F.rank(x) - 1))
|
|
1437
1493
|
output_shape.append(F.shape(x)[-1])
|
|
1494
|
+
ones_shape.append(1)
|
|
1438
1495
|
output_shape.insert(0, axis_size)
|
|
1496
|
+
ones_shape.insert(0, axis_size)
|
|
1439
1497
|
|
|
1440
1498
|
if indexing == "xy":
|
|
1441
1499
|
output_shape[1], output_shape[2] = output_shape[2], output_shape[1]
|
|
1442
|
-
|
|
1443
1500
|
shape = tuple(output_shape)
|
|
1501
|
+
|
|
1502
|
+
input_0, _ = args[0]
|
|
1503
|
+
dtype = F.dtype(input_0)
|
|
1504
|
+
ones_tensor = F.fill(dtype, shape, 1)
|
|
1505
|
+
|
|
1506
|
+
index = 0
|
|
1444
1507
|
vals_out_tuple = ()
|
|
1445
1508
|
for each_arg in args:
|
|
1446
1509
|
x, bdim = each_arg
|
|
1447
1510
|
x = _bdim_at_front(x, bdim, axis_size)
|
|
1448
|
-
|
|
1449
|
-
|
|
1511
|
+
shape_index = (1 - index) if (index <= 1 and indexing == "xy") else index
|
|
1512
|
+
ones_shape[shape_index + 1] = output_shape[shape_index + 1]
|
|
1513
|
+
x = P.Reshape()(x, tuple(ones_shape))
|
|
1514
|
+
output = P.Mul()(x, ones_tensor)
|
|
1450
1515
|
vals_out_tuple = vals_out_tuple + ((output, 0),)
|
|
1516
|
+
ones_shape[shape_index + 1] = 1
|
|
1517
|
+
index = index + 1
|
|
1451
1518
|
|
|
1452
1519
|
return vals_out_tuple
|
|
1453
1520
|
|
|
@@ -1491,7 +1558,7 @@ def get_gather_vmap_rule(prim, axis_size):
|
|
|
1491
1558
|
else:
|
|
1492
1559
|
prim_name = prim.name
|
|
1493
1560
|
|
|
1494
|
-
@
|
|
1561
|
+
@_primexpr
|
|
1495
1562
|
def process_axis(axis, x_shape_size, has_xdim: bool, has_idim: bool):
|
|
1496
1563
|
if has_xdim and has_idim:
|
|
1497
1564
|
if axis < 0:
|
|
@@ -1505,7 +1572,7 @@ def get_gather_vmap_rule(prim, axis_size):
|
|
|
1505
1572
|
|
|
1506
1573
|
return axis
|
|
1507
1574
|
|
|
1508
|
-
@
|
|
1575
|
+
@_primexpr
|
|
1509
1576
|
def get_x_dst_shape(x_shape, axis):
|
|
1510
1577
|
target_axis_size = x_shape[axis + 1]
|
|
1511
1578
|
x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis + 2:]
|
|
@@ -1705,7 +1772,7 @@ def get_data_format_dim_map_vmap_rule(prim, axis_size):
|
|
|
1705
1772
|
def get_expand_dims_vmap_rule(prim, axis_size):
|
|
1706
1773
|
"""VmapRule for `ExpandDims`."""
|
|
1707
1774
|
|
|
1708
|
-
@
|
|
1775
|
+
@_primexpr
|
|
1709
1776
|
def process_axis(axis, rank, x_dim):
|
|
1710
1777
|
if axis < 0:
|
|
1711
1778
|
axis += rank
|
|
@@ -1799,7 +1866,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
|
|
|
1799
1866
|
else:
|
|
1800
1867
|
prim_axis = None
|
|
1801
1868
|
|
|
1802
|
-
@
|
|
1869
|
+
@_primexpr
|
|
1803
1870
|
def move_axis(axes):
|
|
1804
1871
|
new_axis = ()
|
|
1805
1872
|
for axis in axes:
|
|
@@ -1809,7 +1876,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
|
|
|
1809
1876
|
new_axis = new_axis + (axis + 1,)
|
|
1810
1877
|
return new_axis
|
|
1811
1878
|
|
|
1812
|
-
@
|
|
1879
|
+
@_primexpr
|
|
1813
1880
|
def generate_all_axis_except_first(x_rank):
|
|
1814
1881
|
new_axis = ()
|
|
1815
1882
|
for i in range(1, x_rank, 1):
|
|
@@ -1838,6 +1905,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
|
|
|
1838
1905
|
batch_squeeze = P.Squeeze(axis=new_axis)
|
|
1839
1906
|
out = batch_squeeze(x)
|
|
1840
1907
|
return out, 0
|
|
1908
|
+
|
|
1841
1909
|
return vmap_rule
|
|
1842
1910
|
|
|
1843
1911
|
|
|
@@ -1852,7 +1920,7 @@ def get_stridedslice_vmap_rule(prim, axis_size):
|
|
|
1852
1920
|
batch_stridedslice = P.StridedSlice(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
|
|
1853
1921
|
new_shrink_axis_mask)
|
|
1854
1922
|
|
|
1855
|
-
@
|
|
1923
|
+
@_primexpr
|
|
1856
1924
|
def get_new_begin_end_strided(begin, end, strided):
|
|
1857
1925
|
new_begin = (0,) + begin
|
|
1858
1926
|
new_end = (0,) + end
|
|
@@ -1891,9 +1959,9 @@ def get_stridedslice_grad_vmap_rule(prim, axis_size):
|
|
|
1891
1959
|
new_new_axis_mask = prim.new_axis_mask * 2
|
|
1892
1960
|
new_shrink_axis_mask = prim.shrink_axis_mask * 2
|
|
1893
1961
|
batch_stridedslice_grad = G.StridedSliceGrad(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
|
|
1894
|
-
|
|
1962
|
+
new_shrink_axis_mask)
|
|
1895
1963
|
|
|
1896
|
-
@
|
|
1964
|
+
@_primexpr
|
|
1897
1965
|
def get_new_xshape_begin_end_strided(xshape, begin, end, strided):
|
|
1898
1966
|
new_xshape = (axis_size,) + xshape
|
|
1899
1967
|
new_begin = (0,) + begin
|
|
@@ -1984,6 +2052,30 @@ def get_im2col_vmap_rule(prim, axis_size):
|
|
|
1984
2052
|
return vmap_rule
|
|
1985
2053
|
|
|
1986
2054
|
|
|
2055
|
+
@vmap_rules_getters.register(P.Split)
|
|
2056
|
+
def get_split_vmap_rule(prim, axis_size):
|
|
2057
|
+
"""VmapRule for `Split`."""
|
|
2058
|
+
|
|
2059
|
+
axis = prim.axis
|
|
2060
|
+
if axis >= 0:
|
|
2061
|
+
axis += 1
|
|
2062
|
+
batch_prim = P.Split(axis, prim.output_num)
|
|
2063
|
+
|
|
2064
|
+
def vmap_rule(x_bdim):
|
|
2065
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
|
|
2066
|
+
if is_all_none:
|
|
2067
|
+
return result
|
|
2068
|
+
x, x_dim = x_bdim
|
|
2069
|
+
x = _bdim_at_front(x, x_dim, axis_size)
|
|
2070
|
+
outputs = batch_prim(x)
|
|
2071
|
+
output = ()
|
|
2072
|
+
for out in outputs:
|
|
2073
|
+
output = output + ((out, 0),)
|
|
2074
|
+
return output
|
|
2075
|
+
|
|
2076
|
+
return vmap_rule
|
|
2077
|
+
|
|
2078
|
+
|
|
1987
2079
|
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(NonZero)(get_unsupported_dynamic_vmap_rule)
|
|
1988
2080
|
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
|
|
1989
2081
|
get_unsupported_dynamic_vmap_rule = \
|
mindspore/ops/_vmap/vmap_base.py
CHANGED
|
@@ -21,12 +21,13 @@ from mindspore.common import Tensor
|
|
|
21
21
|
from mindspore.ops import operations as P
|
|
22
22
|
from mindspore.ops import functional as F
|
|
23
23
|
from mindspore.ops import constexpr
|
|
24
|
+
from mindspore.ops.primitive import _primexpr
|
|
24
25
|
from mindspore.ops.operations import math_ops
|
|
25
26
|
from mindspore.ops.operations import _grad_ops as G
|
|
26
27
|
from mindspore.ops.operations import nn_ops as nps
|
|
27
|
-
from mindspore.ops.
|
|
28
|
-
from mindspore.ops.primitive import Primitive
|
|
29
|
-
from mindspore.ops.operations.random_ops import UniformCandidateSampler
|
|
28
|
+
from mindspore.ops.function import _VmapGeneralPreprocess
|
|
29
|
+
from mindspore.ops.primitive import Primitive, _PrimitiveC
|
|
30
|
+
from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
|
|
30
31
|
from mindspore.ops._grad.grad_base import BpropRegistry as VmapRuleRegistry
|
|
31
32
|
|
|
32
33
|
|
|
@@ -41,7 +42,7 @@ def get_vmap_rule(prim, axis_size):
|
|
|
41
42
|
return None
|
|
42
43
|
|
|
43
44
|
|
|
44
|
-
@
|
|
45
|
+
@_primexpr
|
|
45
46
|
def _get_broadcast_shape_with_front_axis(x_shape, y_shape):
|
|
46
47
|
""" Explicitly matched with the broadcast shape, that is, 1 is added to the broadcast position. """
|
|
47
48
|
x_len = len(x_shape)
|
|
@@ -86,7 +87,7 @@ def _handle_broadcasting(x, x_shape, y_shape):
|
|
|
86
87
|
return F.reshape(x, broadcast_shape)
|
|
87
88
|
|
|
88
89
|
|
|
89
|
-
@
|
|
90
|
+
@_primexpr
|
|
90
91
|
def _get_broadcasting_with_front_axis_additional_axis(x_shape, y_shape):
|
|
91
92
|
""" Get the axes that are inserted after broadcasting.
|
|
92
93
|
Args:
|
|
@@ -129,15 +130,19 @@ def _raise_value_error(info, param=None):
|
|
|
129
130
|
raise ValueError(info + f"{param}")
|
|
130
131
|
|
|
131
132
|
|
|
132
|
-
@
|
|
133
|
+
@_primexpr
|
|
133
134
|
def _get_broadcast_shape(x_shape, dst, axis_size):
|
|
134
135
|
"""Get the target shape for broadcast array."""
|
|
136
|
+
def _check(dst, broadcast_ndim):
|
|
137
|
+
if dst < -broadcast_ndim or dst >= broadcast_ndim:
|
|
138
|
+
_raise_value_error("Destination axis {} is out of bounds for array of dimension"
|
|
139
|
+
" [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
|
|
140
|
+
|
|
135
141
|
x_ndim = len(x_shape)
|
|
136
142
|
broadcast_ndim = x_ndim + 1
|
|
137
143
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
" [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
|
|
144
|
+
_check(dst, broadcast_ndim)
|
|
145
|
+
|
|
141
146
|
if dst < 0:
|
|
142
147
|
dst = broadcast_ndim + dst
|
|
143
148
|
|
|
@@ -190,6 +195,10 @@ def vmap_unstack(dim, val):
|
|
|
190
195
|
return P.Unstack(dim)(val)
|
|
191
196
|
|
|
192
197
|
|
|
198
|
+
def vmap_stack(val):
|
|
199
|
+
return P.Stack()(val)
|
|
200
|
+
|
|
201
|
+
|
|
193
202
|
def vmap_general_output_process(output):
|
|
194
203
|
""" Match output to axis 0"""
|
|
195
204
|
vals_out_tuple = ()
|
|
@@ -416,6 +425,8 @@ def _vmap_clone_prim(prim):
|
|
|
416
425
|
"""
|
|
417
426
|
Cloning a new primitive object same as `prim`.
|
|
418
427
|
"""
|
|
428
|
+
if isinstance(prim, _PrimitiveC):
|
|
429
|
+
return _PrimitiveC(prim.name, prim.attrs)
|
|
419
430
|
new_ops = _ops_vmap_clone_prim_dict.get(prim.name, None)
|
|
420
431
|
if new_ops is None:
|
|
421
432
|
raise ValueError("Failed to get the primitive object of {} from `_ops_vmap_clone_prim_dict`. Please register "
|
|
@@ -433,7 +444,7 @@ def _vmap_clone_prim(prim):
|
|
|
433
444
|
return cloned
|
|
434
445
|
|
|
435
446
|
|
|
436
|
-
@
|
|
447
|
+
@_primexpr
|
|
437
448
|
def _get_reduce_batch_axis(axis, x_dim, x_ndim):
|
|
438
449
|
"""get batch_axis for reduce* operation."""
|
|
439
450
|
# For axis, it's value in Union[int, list, tuple]
|
|
@@ -481,6 +492,7 @@ _ops_vmap_clone_prim_dict = {
|
|
|
481
492
|
"ApplyAdaMax": P.ApplyAdaMax,
|
|
482
493
|
"ApplyAdadelta": P.ApplyAdadelta,
|
|
483
494
|
"ApplyRMSProp": P.ApplyRMSProp,
|
|
495
|
+
'Adam': P.Adam,
|
|
484
496
|
"ApplyCenteredRMSProp": P.ApplyCenteredRMSProp,
|
|
485
497
|
"ApplyFtrl": P.ApplyFtrl,
|
|
486
498
|
"ApplyGradientDescent": P.ApplyGradientDescent,
|
|
@@ -508,4 +520,6 @@ _ops_vmap_clone_prim_dict = {
|
|
|
508
520
|
"SparseApplyAdagrad": P.SparseApplyAdagrad,
|
|
509
521
|
"SparseApplyAdagradV2": P.SparseApplyAdagradV2,
|
|
510
522
|
"SparseApplyFtrl": P.SparseApplyFtrl,
|
|
523
|
+
"RandomShuffle": RandomShuffle,
|
|
524
|
+
"RandomChoiceWithMask": P.RandomChoiceWithMask
|
|
511
525
|
}
|
|
@@ -16,9 +16,9 @@
|
|
|
16
16
|
"""convolution vmap impl"""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
|
|
19
|
-
import numpy as np
|
|
20
19
|
import mindspore.numpy as mnp
|
|
21
20
|
from mindspore.ops import constexpr
|
|
21
|
+
from mindspore.ops.primitive import _primexpr
|
|
22
22
|
from mindspore.ops import operations as P
|
|
23
23
|
from mindspore.ops import functional as F
|
|
24
24
|
from mindspore.ops.operations import nn_ops as nps
|
|
@@ -142,7 +142,7 @@ def get_conv3d_backprop_filter_vmap_rule(prim, axis_size):
|
|
|
142
142
|
return vmap_rule
|
|
143
143
|
|
|
144
144
|
|
|
145
|
-
@
|
|
145
|
+
@_primexpr
|
|
146
146
|
def _get_reshape_src_dim(data_dim, cmp_dim):
|
|
147
147
|
"""Get source dim for reshape"""
|
|
148
148
|
if data_dim > cmp_dim:
|
|
@@ -154,7 +154,7 @@ def _get_reshape_src_dim(data_dim, cmp_dim):
|
|
|
154
154
|
return expand_dim, merge_dim
|
|
155
155
|
|
|
156
156
|
|
|
157
|
-
@
|
|
157
|
+
@_primexpr
|
|
158
158
|
def _get_merge_shape(src_dim, dst_dim, shape):
|
|
159
159
|
"""Get new shape for merging the src_dim and dst_dim. The dst_dim is the value after removing src_dim."""
|
|
160
160
|
new_shape = [shape[i] for i in range(len(shape)) if i != src_dim]
|
|
@@ -171,13 +171,10 @@ def _reshape_merge_dims(src_dim, dst_dim, target):
|
|
|
171
171
|
return output, new_shape
|
|
172
172
|
|
|
173
173
|
|
|
174
|
-
@
|
|
174
|
+
@_primexpr
|
|
175
175
|
def _get_expand_shape(src_dim, dst_size, shape, prim_name):
|
|
176
176
|
"""Get new shape for splitting src_dim into dst_size parts."""
|
|
177
|
-
dst_size2
|
|
178
|
-
if remainder != 0:
|
|
179
|
-
_raise_value_error("The remainder of {} / {} should be 0, "
|
|
180
|
-
"but got {} in {}.".format(shape[src_dim], dst_size, remainder, prim_name))
|
|
177
|
+
dst_size2 = shape[src_dim] // dst_size
|
|
181
178
|
new_shape = list(shape)
|
|
182
179
|
new_shape[src_dim:(src_dim + 1)] = [dst_size, dst_size2]
|
|
183
180
|
return tuple(new_shape)
|
|
@@ -190,7 +187,7 @@ def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
|
|
|
190
187
|
return F.reshape(target, new_shape)
|
|
191
188
|
|
|
192
189
|
|
|
193
|
-
@
|
|
190
|
+
@_primexpr
|
|
194
191
|
def _get_new_size_by_index(input_size, batch_size, index):
|
|
195
192
|
"""Get the new size of input_size by multiplying input_size[index] by batch_size."""
|
|
196
193
|
new_size = ()
|
|
@@ -201,7 +198,7 @@ def _get_new_size_by_index(input_size, batch_size, index):
|
|
|
201
198
|
return tuple(new_size)
|
|
202
199
|
|
|
203
200
|
|
|
204
|
-
@
|
|
201
|
+
@_primexpr
|
|
205
202
|
def _update_group_attr(prim, groups, batch_size):
|
|
206
203
|
"""Set new value for 'group' attribute of the convolution primitive."""
|
|
207
204
|
group = groups * batch_size
|
|
@@ -17,9 +17,9 @@
|
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
|
|
19
19
|
from mindspore.ops import functional as F
|
|
20
|
-
from mindspore.ops import
|
|
20
|
+
from mindspore.ops.primitive import _primexpr
|
|
21
21
|
from mindspore.ops.operations import _grad_ops as G
|
|
22
|
-
from mindspore.ops.
|
|
22
|
+
from mindspore.ops.function import _VmapGeneralRule
|
|
23
23
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
|
|
24
24
|
_handle_broadcasting, get_unary_grad_vmap_rule, _get_broadcasting_with_front_axis_additional_axis
|
|
25
25
|
|
|
@@ -36,7 +36,7 @@ def get_broadcast_binary_op_grad_vmap_rule(prim, axis_size):
|
|
|
36
36
|
if isinstance(prim, str):
|
|
37
37
|
prim = broadcast_binary_op_grad_map.get(prim)()
|
|
38
38
|
|
|
39
|
-
@
|
|
39
|
+
@_primexpr
|
|
40
40
|
def get_longest_shape(x_shape, y_shape, g_shape):
|
|
41
41
|
x_rank = len(x_shape)
|
|
42
42
|
y_rank = len(y_shape)
|
|
@@ -148,7 +148,7 @@ def get_median_grad_vmap_rule(prim, axis_size):
|
|
|
148
148
|
axis = prim.axis
|
|
149
149
|
keep_dims = prim.keep_dims
|
|
150
150
|
|
|
151
|
-
@
|
|
151
|
+
@_primexpr
|
|
152
152
|
def trans_grad_axis(axis, rank, dim, keep_dims):
|
|
153
153
|
if axis < 0:
|
|
154
154
|
axis += rank - 1
|