mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -16,25 +16,26 @@
|
|
|
16
16
|
"""array_ops vmap impl."""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
|
|
19
|
-
import numpy as np
|
|
20
19
|
import mindspore
|
|
21
20
|
import mindspore.numpy as mnp
|
|
22
21
|
from mindspore import ops
|
|
23
22
|
from mindspore.common import Tensor
|
|
23
|
+
from mindspore._c_expression import Tensor as Tensor_
|
|
24
24
|
from mindspore.ops import operations as P
|
|
25
25
|
from mindspore.ops import functional as F
|
|
26
|
-
from mindspore.ops import constexpr
|
|
26
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
27
27
|
from mindspore.ops.operations._grad_ops import MaskedSelectGrad
|
|
28
28
|
from mindspore.ops.operations import _grad_ops as G
|
|
29
29
|
from mindspore.ops.operations.array_ops import Fills, UniqueConsecutive, Col2Im, NonZero, IndexFill, \
|
|
30
30
|
TensorScatterElements
|
|
31
31
|
from mindspore.ops.operations.random_ops import RandomPoisson
|
|
32
|
+
from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
|
|
32
33
|
from mindspore.ops.primitive import Primitive
|
|
33
34
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
|
|
34
35
|
_raise_value_error, _vmap_clone_prim, _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis, \
|
|
35
36
|
get_unop_vmap_rule, _get_reduce_out_dim, _get_reduce_batch_axis, \
|
|
36
37
|
_bdim_at_any
|
|
37
|
-
from mindspore.ops.
|
|
38
|
+
from mindspore.ops.function import _VmapGeneralRule
|
|
38
39
|
|
|
39
40
|
|
|
40
41
|
@vmap_rules_getters.register(P.NoRepeatNGram)
|
|
@@ -137,7 +138,7 @@ def get_arg_min_max_with_value_vmap_rule(prim, axis_size):
|
|
|
137
138
|
return vmap_rule
|
|
138
139
|
|
|
139
140
|
|
|
140
|
-
@
|
|
141
|
+
@_primexpr
|
|
141
142
|
def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
142
143
|
"""
|
|
143
144
|
Generate prefix by indices shape, whose -1 axis value is the index value of axis 0.
|
|
@@ -147,14 +148,16 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
147
148
|
the generated prefix is a Tensor([[[0], [0]],
|
|
148
149
|
[[1], [1]]])
|
|
149
150
|
"""
|
|
150
|
-
|
|
151
|
-
|
|
151
|
+
def _check(indices_shape):
|
|
152
|
+
if not indices_shape:
|
|
153
|
+
raise ValueError("indices_shape is empty in _get_prefix.")
|
|
152
154
|
|
|
155
|
+
_check(indices_shape)
|
|
153
156
|
indices_len = len(indices_shape)
|
|
154
|
-
|
|
155
157
|
if indices_len == 1:
|
|
156
|
-
prefix =
|
|
157
|
-
|
|
158
|
+
prefix = P.Range()(Tensor(0, indices_dtype), P.Fill()(
|
|
159
|
+
indices_dtype, (), axis_size), Tensor(1, indices_dtype))
|
|
160
|
+
return prefix
|
|
158
161
|
|
|
159
162
|
indices_end = indices_len - 1
|
|
160
163
|
prefix_shape = ()
|
|
@@ -169,8 +172,9 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
169
172
|
else:
|
|
170
173
|
expand_shape = expand_shape + (1,)
|
|
171
174
|
|
|
172
|
-
prefix =
|
|
173
|
-
|
|
175
|
+
prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(Tensor(
|
|
176
|
+
0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype)), expand_shape))
|
|
177
|
+
return prefix
|
|
174
178
|
|
|
175
179
|
|
|
176
180
|
@vmap_rules_getters.register(P.Transpose)
|
|
@@ -179,7 +183,7 @@ def get_transpose_vmap_rule(prim, axis_size):
|
|
|
179
183
|
if isinstance(prim, str):
|
|
180
184
|
prim = Primitive(prim)
|
|
181
185
|
|
|
182
|
-
@
|
|
186
|
+
@_primexpr
|
|
183
187
|
def _get_transpose_batch_perm(dim, perm, x_rank):
|
|
184
188
|
"""Generate batch_perm based on the original perm of transpose operation and dim of the input."""
|
|
185
189
|
if dim < 0:
|
|
@@ -223,7 +227,7 @@ def get_tile_vmap_rule(prim, axis_size):
|
|
|
223
227
|
if isinstance(prim, str):
|
|
224
228
|
prim = Primitive(prim)
|
|
225
229
|
|
|
226
|
-
@
|
|
230
|
+
@_primexpr
|
|
227
231
|
def _get_batch_multiples(input_shape, dim, multiples):
|
|
228
232
|
input_ndim = len(input_shape)
|
|
229
233
|
multiples_ndim = len(multiples)
|
|
@@ -352,8 +356,13 @@ def get_unstack_vmap_rule(prim, axis_size):
|
|
|
352
356
|
def get_reshape_vmap_rule(prim, axis_size):
|
|
353
357
|
"""VmapRule for `Reshape` operation."""
|
|
354
358
|
|
|
355
|
-
|
|
359
|
+
|
|
360
|
+
@_primexpr
|
|
356
361
|
def get_batch_shape(x_shape, x_dim, target_shape, axis_size):
|
|
362
|
+
def _check(neg_index, target_shape):
|
|
363
|
+
if neg_index != -1:
|
|
364
|
+
raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
|
|
365
|
+
|
|
357
366
|
if x_dim == 0:
|
|
358
367
|
return (axis_size,) + target_shape, 0, False
|
|
359
368
|
|
|
@@ -364,19 +373,21 @@ def get_reshape_vmap_rule(prim, axis_size):
|
|
|
364
373
|
dim_prod = 1
|
|
365
374
|
for i, shp_i in enumerate(target_shape):
|
|
366
375
|
if shp_i == -1:
|
|
367
|
-
|
|
368
|
-
raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
|
|
376
|
+
_check(neg_index, target_shape)
|
|
369
377
|
neg_index = i
|
|
370
378
|
else:
|
|
371
379
|
dim_prod *= shp_i
|
|
372
|
-
arr_prod =
|
|
380
|
+
arr_prod = 1
|
|
381
|
+
for i in x_shape:
|
|
382
|
+
arr_prod *= i
|
|
373
383
|
target_shape_list = list(target_shape)
|
|
374
384
|
if neg_index != -1:
|
|
375
385
|
neg_index_size = int(arr_prod // (dim_prod * axis_size))
|
|
376
386
|
target_shape_list[neg_index] = neg_index_size
|
|
377
387
|
|
|
378
|
-
arr_prod_before_dim =
|
|
379
|
-
|
|
388
|
+
arr_prod_before_dim = 1
|
|
389
|
+
for i in x_shape[:x_dim]:
|
|
390
|
+
arr_prod_before_dim *= i
|
|
380
391
|
dim_prod = 1
|
|
381
392
|
for i, shp_i in enumerate(target_shape_list, start=1):
|
|
382
393
|
dim_prod *= shp_i
|
|
@@ -421,7 +432,7 @@ def get_reverse_sequence_vmap_rule(prim, axis_size):
|
|
|
421
432
|
batch_dim = prim.batch_dim_
|
|
422
433
|
seq_dim = prim.seq_dim_
|
|
423
434
|
|
|
424
|
-
@
|
|
435
|
+
@_primexpr
|
|
425
436
|
def get_batch_seq_dim(dim, batch_dim_, seq_dim_):
|
|
426
437
|
if dim is None:
|
|
427
438
|
batch_dim_ += 1
|
|
@@ -437,7 +448,7 @@ def get_reverse_sequence_vmap_rule(prim, axis_size):
|
|
|
437
448
|
seq_dim_ += 1
|
|
438
449
|
return batch_dim_, seq_dim_
|
|
439
450
|
|
|
440
|
-
@
|
|
451
|
+
@_primexpr
|
|
441
452
|
def get_seq_dim(dim, batch_dim_, seq_dim_):
|
|
442
453
|
if dim is None:
|
|
443
454
|
return seq_dim_
|
|
@@ -557,20 +568,19 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
|
|
|
557
568
|
Reshape the output tensor to `[10, 6, 4, 5]`
|
|
558
569
|
"""
|
|
559
570
|
|
|
560
|
-
@
|
|
571
|
+
@_primexpr
|
|
561
572
|
def _refine_shape(shape, bdim_size):
|
|
562
573
|
offset = shape[0]
|
|
563
574
|
return (bdim_size * shape[0],) + tuple(shape[1:]), offset, (bdim_size,) + tuple(shape)
|
|
564
575
|
|
|
565
|
-
@
|
|
576
|
+
@_primexpr
|
|
566
577
|
def _gen_indices_offset(shape, offset):
|
|
567
578
|
# original rank(indices.shape) is required >= 2, so indices with batch dim's rank >= 3.
|
|
568
|
-
shape =
|
|
569
|
-
val =
|
|
570
|
-
val = np.reshape(val, (shape[0], shape[-1]))
|
|
579
|
+
shape = (shape[0],) + (1,) * (len(shape) - 2) + (shape[-1],)
|
|
580
|
+
val = P.Zeros()((shape[0], shape[-1]), mindspore.int32)
|
|
571
581
|
for i in range(shape[0]):
|
|
572
582
|
val[i, 0] = i * offset
|
|
573
|
-
return
|
|
583
|
+
return P.Reshape()(val, shape)
|
|
574
584
|
|
|
575
585
|
if isinstance(prim, str):
|
|
576
586
|
prim = Primitive(prim)
|
|
@@ -591,7 +601,7 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
|
|
|
591
601
|
indices_shape = F.shape(indices)
|
|
592
602
|
indices_dtype = F.dtype(indices)
|
|
593
603
|
offset_val = _gen_indices_offset(indices_shape, offset)
|
|
594
|
-
indices_offset =
|
|
604
|
+
indices_offset = P.Cast()(offset_val, indices_dtype)
|
|
595
605
|
new_indices = P.Add()(indices, indices_offset)
|
|
596
606
|
out = prim(new_indices, updates, new_shape)
|
|
597
607
|
real_out = P.Reshape()(out, out_shape)
|
|
@@ -839,6 +849,62 @@ def get_fill_vmap_rule(prim, axis_size):
|
|
|
839
849
|
return vmap_rule
|
|
840
850
|
|
|
841
851
|
|
|
852
|
+
@constexpr
|
|
853
|
+
def to_tensor_with_type(x, type):
|
|
854
|
+
"""x to Tensor with type"""
|
|
855
|
+
return Tensor(x, type)
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
@vmap_rules_getters.register(P.FillV2)
|
|
859
|
+
def get_fill_v2_vmap_rule(prim, axis_size):
|
|
860
|
+
"""VmapRule for `FillV2` operation."""
|
|
861
|
+
if isinstance(prim, str):
|
|
862
|
+
prim = Primitive(prim)
|
|
863
|
+
|
|
864
|
+
def vmap_rule(shape_bdim, value_bdim):
|
|
865
|
+
is_all_none, result = vmap_general_preprocess(prim, shape_bdim, value_bdim)
|
|
866
|
+
if is_all_none:
|
|
867
|
+
return result
|
|
868
|
+
|
|
869
|
+
value_shape, shape_dim = shape_bdim
|
|
870
|
+
if shape_dim is not None:
|
|
871
|
+
_raise_value_error(
|
|
872
|
+
"The source axis of `shape` in `P.FillV2` must be None, but got {}."
|
|
873
|
+
.format(shape_dim))
|
|
874
|
+
|
|
875
|
+
value, vdim = value_bdim
|
|
876
|
+
value_rank = F.rank(value)
|
|
877
|
+
if value_rank != 1 or vdim != 0:
|
|
878
|
+
_raise_value_error(
|
|
879
|
+
"The `value` in `P.FillV2` must be constant value, thus the value only "
|
|
880
|
+
"can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
|
|
881
|
+
"{} with source axis: {}.".format(value_rank, vdim))
|
|
882
|
+
value = F.reshape(value, (axis_size,) + (1,) * len(value_shape))
|
|
883
|
+
|
|
884
|
+
out = None
|
|
885
|
+
if isinstance(value_shape, (Tensor_, Tensor)):
|
|
886
|
+
value_shape_rank = F.rank(value_shape)
|
|
887
|
+
if value_shape_rank != 1:
|
|
888
|
+
_raise_value_error(
|
|
889
|
+
"The `shape` in `P.FillV2` must be 1-D tensor, thus the shape only "
|
|
890
|
+
"can be rank: 1, but got shape rank: "
|
|
891
|
+
"{}.".format(value_shape_rank))
|
|
892
|
+
axis_size_tensor = to_tensor_with_type((axis_size,),
|
|
893
|
+
F.dtype(value_shape))
|
|
894
|
+
broad_cast_shape = F.concat((axis_size_tensor, value_shape))
|
|
895
|
+
out = DynamicBroadcastTo()(value, broad_cast_shape)
|
|
896
|
+
elif isinstance(value_shape, tuple):
|
|
897
|
+
out = P.BroadcastTo((axis_size,) + value_shape)(value)
|
|
898
|
+
else:
|
|
899
|
+
_raise_value_error(
|
|
900
|
+
f"For `P.FillV2`, the input `shape` should be Tuple or Tensor, but got `shape`: {value_shape}."
|
|
901
|
+
)
|
|
902
|
+
|
|
903
|
+
return out, 0
|
|
904
|
+
|
|
905
|
+
return vmap_rule
|
|
906
|
+
|
|
907
|
+
|
|
842
908
|
@vmap_rules_getters.register(Fills)
|
|
843
909
|
def get_fills_vmap_rule(prim, axis_size):
|
|
844
910
|
"""VmapRule for `Fills` operation."""
|
|
@@ -1414,6 +1480,7 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1414
1480
|
"The input number of P.Meshgrid must be greater than 1.")
|
|
1415
1481
|
|
|
1416
1482
|
output_shape = []
|
|
1483
|
+
ones_shape = []
|
|
1417
1484
|
for each_arg in args:
|
|
1418
1485
|
x, bdim = each_arg
|
|
1419
1486
|
if bdim is None:
|
|
@@ -1424,19 +1491,30 @@ def get_meshgrid_vmap_rule(prim, axis_size):
|
|
|
1424
1491
|
_raise_value_error(
|
|
1425
1492
|
"Each input of Meshgrid must be 1D, but got {}.".format(F.rank(x) - 1))
|
|
1426
1493
|
output_shape.append(F.shape(x)[-1])
|
|
1494
|
+
ones_shape.append(1)
|
|
1427
1495
|
output_shape.insert(0, axis_size)
|
|
1496
|
+
ones_shape.insert(0, axis_size)
|
|
1428
1497
|
|
|
1429
1498
|
if indexing == "xy":
|
|
1430
1499
|
output_shape[1], output_shape[2] = output_shape[2], output_shape[1]
|
|
1431
|
-
|
|
1432
1500
|
shape = tuple(output_shape)
|
|
1501
|
+
|
|
1502
|
+
input_0, _ = args[0]
|
|
1503
|
+
dtype = F.dtype(input_0)
|
|
1504
|
+
ones_tensor = F.fill(dtype, shape, 1)
|
|
1505
|
+
|
|
1506
|
+
index = 0
|
|
1433
1507
|
vals_out_tuple = ()
|
|
1434
1508
|
for each_arg in args:
|
|
1435
1509
|
x, bdim = each_arg
|
|
1436
1510
|
x = _bdim_at_front(x, bdim, axis_size)
|
|
1437
|
-
|
|
1438
|
-
|
|
1511
|
+
shape_index = (1 - index) if (index <= 1 and indexing == "xy") else index
|
|
1512
|
+
ones_shape[shape_index + 1] = output_shape[shape_index + 1]
|
|
1513
|
+
x = P.Reshape()(x, tuple(ones_shape))
|
|
1514
|
+
output = P.Mul()(x, ones_tensor)
|
|
1439
1515
|
vals_out_tuple = vals_out_tuple + ((output, 0),)
|
|
1516
|
+
ones_shape[shape_index + 1] = 1
|
|
1517
|
+
index = index + 1
|
|
1440
1518
|
|
|
1441
1519
|
return vals_out_tuple
|
|
1442
1520
|
|
|
@@ -1480,7 +1558,7 @@ def get_gather_vmap_rule(prim, axis_size):
|
|
|
1480
1558
|
else:
|
|
1481
1559
|
prim_name = prim.name
|
|
1482
1560
|
|
|
1483
|
-
@
|
|
1561
|
+
@_primexpr
|
|
1484
1562
|
def process_axis(axis, x_shape_size, has_xdim: bool, has_idim: bool):
|
|
1485
1563
|
if has_xdim and has_idim:
|
|
1486
1564
|
if axis < 0:
|
|
@@ -1494,7 +1572,7 @@ def get_gather_vmap_rule(prim, axis_size):
|
|
|
1494
1572
|
|
|
1495
1573
|
return axis
|
|
1496
1574
|
|
|
1497
|
-
@
|
|
1575
|
+
@_primexpr
|
|
1498
1576
|
def get_x_dst_shape(x_shape, axis):
|
|
1499
1577
|
target_axis_size = x_shape[axis + 1]
|
|
1500
1578
|
x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis + 2:]
|
|
@@ -1694,7 +1772,7 @@ def get_data_format_dim_map_vmap_rule(prim, axis_size):
|
|
|
1694
1772
|
def get_expand_dims_vmap_rule(prim, axis_size):
|
|
1695
1773
|
"""VmapRule for `ExpandDims`."""
|
|
1696
1774
|
|
|
1697
|
-
@
|
|
1775
|
+
@_primexpr
|
|
1698
1776
|
def process_axis(axis, rank, x_dim):
|
|
1699
1777
|
if axis < 0:
|
|
1700
1778
|
axis += rank
|
|
@@ -1788,7 +1866,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
|
|
|
1788
1866
|
else:
|
|
1789
1867
|
prim_axis = None
|
|
1790
1868
|
|
|
1791
|
-
@
|
|
1869
|
+
@_primexpr
|
|
1792
1870
|
def move_axis(axes):
|
|
1793
1871
|
new_axis = ()
|
|
1794
1872
|
for axis in axes:
|
|
@@ -1798,7 +1876,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
|
|
|
1798
1876
|
new_axis = new_axis + (axis + 1,)
|
|
1799
1877
|
return new_axis
|
|
1800
1878
|
|
|
1801
|
-
@
|
|
1879
|
+
@_primexpr
|
|
1802
1880
|
def generate_all_axis_except_first(x_rank):
|
|
1803
1881
|
new_axis = ()
|
|
1804
1882
|
for i in range(1, x_rank, 1):
|
|
@@ -1842,7 +1920,7 @@ def get_stridedslice_vmap_rule(prim, axis_size):
|
|
|
1842
1920
|
batch_stridedslice = P.StridedSlice(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
|
|
1843
1921
|
new_shrink_axis_mask)
|
|
1844
1922
|
|
|
1845
|
-
@
|
|
1923
|
+
@_primexpr
|
|
1846
1924
|
def get_new_begin_end_strided(begin, end, strided):
|
|
1847
1925
|
new_begin = (0,) + begin
|
|
1848
1926
|
new_end = (0,) + end
|
|
@@ -1883,7 +1961,7 @@ def get_stridedslice_grad_vmap_rule(prim, axis_size):
|
|
|
1883
1961
|
batch_stridedslice_grad = G.StridedSliceGrad(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
|
|
1884
1962
|
new_shrink_axis_mask)
|
|
1885
1963
|
|
|
1886
|
-
@
|
|
1964
|
+
@_primexpr
|
|
1887
1965
|
def get_new_xshape_begin_end_strided(xshape, begin, end, strided):
|
|
1888
1966
|
new_xshape = (axis_size,) + xshape
|
|
1889
1967
|
new_begin = (0,) + begin
|
mindspore/ops/_vmap/vmap_base.py
CHANGED
|
@@ -21,11 +21,12 @@ from mindspore.common import Tensor
|
|
|
21
21
|
from mindspore.ops import operations as P
|
|
22
22
|
from mindspore.ops import functional as F
|
|
23
23
|
from mindspore.ops import constexpr
|
|
24
|
+
from mindspore.ops.primitive import _primexpr
|
|
24
25
|
from mindspore.ops.operations import math_ops
|
|
25
26
|
from mindspore.ops.operations import _grad_ops as G
|
|
26
27
|
from mindspore.ops.operations import nn_ops as nps
|
|
27
|
-
from mindspore.ops.
|
|
28
|
-
from mindspore.ops.primitive import Primitive
|
|
28
|
+
from mindspore.ops.function import _VmapGeneralPreprocess
|
|
29
|
+
from mindspore.ops.primitive import Primitive, _PrimitiveC
|
|
29
30
|
from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
|
|
30
31
|
from mindspore.ops._grad.grad_base import BpropRegistry as VmapRuleRegistry
|
|
31
32
|
|
|
@@ -41,7 +42,7 @@ def get_vmap_rule(prim, axis_size):
|
|
|
41
42
|
return None
|
|
42
43
|
|
|
43
44
|
|
|
44
|
-
@
|
|
45
|
+
@_primexpr
|
|
45
46
|
def _get_broadcast_shape_with_front_axis(x_shape, y_shape):
|
|
46
47
|
""" Explicitly matched with the broadcast shape, that is, 1 is added to the broadcast position. """
|
|
47
48
|
x_len = len(x_shape)
|
|
@@ -86,7 +87,7 @@ def _handle_broadcasting(x, x_shape, y_shape):
|
|
|
86
87
|
return F.reshape(x, broadcast_shape)
|
|
87
88
|
|
|
88
89
|
|
|
89
|
-
@
|
|
90
|
+
@_primexpr
|
|
90
91
|
def _get_broadcasting_with_front_axis_additional_axis(x_shape, y_shape):
|
|
91
92
|
""" Get the axes that are inserted after broadcasting.
|
|
92
93
|
Args:
|
|
@@ -129,15 +130,19 @@ def _raise_value_error(info, param=None):
|
|
|
129
130
|
raise ValueError(info + f"{param}")
|
|
130
131
|
|
|
131
132
|
|
|
132
|
-
@
|
|
133
|
+
@_primexpr
|
|
133
134
|
def _get_broadcast_shape(x_shape, dst, axis_size):
|
|
134
135
|
"""Get the target shape for broadcast array."""
|
|
136
|
+
def _check(dst, broadcast_ndim):
|
|
137
|
+
if dst < -broadcast_ndim or dst >= broadcast_ndim:
|
|
138
|
+
_raise_value_error("Destination axis {} is out of bounds for array of dimension"
|
|
139
|
+
" [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
|
|
140
|
+
|
|
135
141
|
x_ndim = len(x_shape)
|
|
136
142
|
broadcast_ndim = x_ndim + 1
|
|
137
143
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
" [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
|
|
144
|
+
_check(dst, broadcast_ndim)
|
|
145
|
+
|
|
141
146
|
if dst < 0:
|
|
142
147
|
dst = broadcast_ndim + dst
|
|
143
148
|
|
|
@@ -420,6 +425,8 @@ def _vmap_clone_prim(prim):
|
|
|
420
425
|
"""
|
|
421
426
|
Cloning a new primitive object same as `prim`.
|
|
422
427
|
"""
|
|
428
|
+
if isinstance(prim, _PrimitiveC):
|
|
429
|
+
return _PrimitiveC(prim.name, prim.attrs)
|
|
423
430
|
new_ops = _ops_vmap_clone_prim_dict.get(prim.name, None)
|
|
424
431
|
if new_ops is None:
|
|
425
432
|
raise ValueError("Failed to get the primitive object of {} from `_ops_vmap_clone_prim_dict`. Please register "
|
|
@@ -437,7 +444,7 @@ def _vmap_clone_prim(prim):
|
|
|
437
444
|
return cloned
|
|
438
445
|
|
|
439
446
|
|
|
440
|
-
@
|
|
447
|
+
@_primexpr
|
|
441
448
|
def _get_reduce_batch_axis(axis, x_dim, x_ndim):
|
|
442
449
|
"""get batch_axis for reduce* operation."""
|
|
443
450
|
# For axis, it's value in Union[int, list, tuple]
|
|
@@ -16,9 +16,9 @@
|
|
|
16
16
|
"""convolution vmap impl"""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
|
|
19
|
-
import numpy as np
|
|
20
19
|
import mindspore.numpy as mnp
|
|
21
20
|
from mindspore.ops import constexpr
|
|
21
|
+
from mindspore.ops.primitive import _primexpr
|
|
22
22
|
from mindspore.ops import operations as P
|
|
23
23
|
from mindspore.ops import functional as F
|
|
24
24
|
from mindspore.ops.operations import nn_ops as nps
|
|
@@ -142,7 +142,7 @@ def get_conv3d_backprop_filter_vmap_rule(prim, axis_size):
|
|
|
142
142
|
return vmap_rule
|
|
143
143
|
|
|
144
144
|
|
|
145
|
-
@
|
|
145
|
+
@_primexpr
|
|
146
146
|
def _get_reshape_src_dim(data_dim, cmp_dim):
|
|
147
147
|
"""Get source dim for reshape"""
|
|
148
148
|
if data_dim > cmp_dim:
|
|
@@ -154,7 +154,7 @@ def _get_reshape_src_dim(data_dim, cmp_dim):
|
|
|
154
154
|
return expand_dim, merge_dim
|
|
155
155
|
|
|
156
156
|
|
|
157
|
-
@
|
|
157
|
+
@_primexpr
|
|
158
158
|
def _get_merge_shape(src_dim, dst_dim, shape):
|
|
159
159
|
"""Get new shape for merging the src_dim and dst_dim. The dst_dim is the value after removing src_dim."""
|
|
160
160
|
new_shape = [shape[i] for i in range(len(shape)) if i != src_dim]
|
|
@@ -171,13 +171,10 @@ def _reshape_merge_dims(src_dim, dst_dim, target):
|
|
|
171
171
|
return output, new_shape
|
|
172
172
|
|
|
173
173
|
|
|
174
|
-
@
|
|
174
|
+
@_primexpr
|
|
175
175
|
def _get_expand_shape(src_dim, dst_size, shape, prim_name):
|
|
176
176
|
"""Get new shape for splitting src_dim into dst_size parts."""
|
|
177
|
-
dst_size2
|
|
178
|
-
if remainder != 0:
|
|
179
|
-
_raise_value_error("The remainder of {} / {} should be 0, "
|
|
180
|
-
"but got {} in {}.".format(shape[src_dim], dst_size, remainder, prim_name))
|
|
177
|
+
dst_size2 = shape[src_dim] // dst_size
|
|
181
178
|
new_shape = list(shape)
|
|
182
179
|
new_shape[src_dim:(src_dim + 1)] = [dst_size, dst_size2]
|
|
183
180
|
return tuple(new_shape)
|
|
@@ -190,7 +187,7 @@ def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
|
|
|
190
187
|
return F.reshape(target, new_shape)
|
|
191
188
|
|
|
192
189
|
|
|
193
|
-
@
|
|
190
|
+
@_primexpr
|
|
194
191
|
def _get_new_size_by_index(input_size, batch_size, index):
|
|
195
192
|
"""Get the new size of input_size by multiplying input_size[index] by batch_size."""
|
|
196
193
|
new_size = ()
|
|
@@ -201,7 +198,7 @@ def _get_new_size_by_index(input_size, batch_size, index):
|
|
|
201
198
|
return tuple(new_size)
|
|
202
199
|
|
|
203
200
|
|
|
204
|
-
@
|
|
201
|
+
@_primexpr
|
|
205
202
|
def _update_group_attr(prim, groups, batch_size):
|
|
206
203
|
"""Set new value for 'group' attribute of the convolution primitive."""
|
|
207
204
|
group = groups * batch_size
|
|
@@ -17,9 +17,9 @@
|
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
|
|
19
19
|
from mindspore.ops import functional as F
|
|
20
|
-
from mindspore.ops import
|
|
20
|
+
from mindspore.ops.primitive import _primexpr
|
|
21
21
|
from mindspore.ops.operations import _grad_ops as G
|
|
22
|
-
from mindspore.ops.
|
|
22
|
+
from mindspore.ops.function import _VmapGeneralRule
|
|
23
23
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
|
|
24
24
|
_handle_broadcasting, get_unary_grad_vmap_rule, _get_broadcasting_with_front_axis_additional_axis
|
|
25
25
|
|
|
@@ -36,7 +36,7 @@ def get_broadcast_binary_op_grad_vmap_rule(prim, axis_size):
|
|
|
36
36
|
if isinstance(prim, str):
|
|
37
37
|
prim = broadcast_binary_op_grad_map.get(prim)()
|
|
38
38
|
|
|
39
|
-
@
|
|
39
|
+
@_primexpr
|
|
40
40
|
def get_longest_shape(x_shape, y_shape, g_shape):
|
|
41
41
|
x_rank = len(x_shape)
|
|
42
42
|
y_rank = len(y_shape)
|
|
@@ -148,7 +148,7 @@ def get_median_grad_vmap_rule(prim, axis_size):
|
|
|
148
148
|
axis = prim.axis
|
|
149
149
|
keep_dims = prim.keep_dims
|
|
150
150
|
|
|
151
|
-
@
|
|
151
|
+
@_primexpr
|
|
152
152
|
def trans_grad_axis(axis, rank, dim, keep_dims):
|
|
153
153
|
if axis < 0:
|
|
154
154
|
axis += rank - 1
|
|
@@ -22,8 +22,9 @@ import mindspore.numpy as mnp
|
|
|
22
22
|
from mindspore.ops.operations import _grad_ops as G
|
|
23
23
|
from mindspore.ops import functional as F
|
|
24
24
|
from mindspore.ops import constexpr
|
|
25
|
+
from mindspore.ops.primitive import _primexpr
|
|
25
26
|
from mindspore.ops.primitive import Primitive
|
|
26
|
-
from mindspore.ops.
|
|
27
|
+
from mindspore.ops.function import _VmapGeneralRule
|
|
27
28
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
|
|
28
29
|
_bdim_at_front, _vmap_clone_prim, _vmap_update_prim_attr, _bdim_at_any, _handle_broadcasting
|
|
29
30
|
|
|
@@ -38,7 +39,7 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
|
|
|
38
39
|
2. And weight only support shape as (C,), while total_weight should be a scalar.
|
|
39
40
|
"""
|
|
40
41
|
|
|
41
|
-
@
|
|
42
|
+
@_primexpr
|
|
42
43
|
def _get_reshape_shape(shape, keep_dim=0):
|
|
43
44
|
new_batch_size = reduce(
|
|
44
45
|
lambda x, y: x * y, shape if keep_dim == 0 else shape[:-keep_dim])
|
|
@@ -397,8 +398,9 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
|
|
|
397
398
|
|
|
398
399
|
@vmap_rules_getters.register(G.MaxPoolGradGrad)
|
|
399
400
|
@vmap_rules_getters.register(G.MaxPoolGradGradWithArgmax)
|
|
401
|
+
@vmap_rules_getters.register(G.MaxPoolGradWithArgmaxV2)
|
|
400
402
|
def get_maxpool_grad_grad_vmap_rule(prim, axis_size):
|
|
401
|
-
"""VmapRule for `MaxPoolGradGrad` and `
|
|
403
|
+
"""VmapRule for `MaxPoolGradGrad`, `MaxPoolGradGradWithArgmax` and `MaxPoolGradWithArgmaxV2`."""
|
|
402
404
|
chw_reverse_index = -3
|
|
403
405
|
|
|
404
406
|
def vmap_rule(in0_bdim, in1_bdim, in2_bdim):
|
|
@@ -557,7 +559,7 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
|
|
|
557
559
|
return prim_attr_axis
|
|
558
560
|
return prim_attr_axis + 1
|
|
559
561
|
|
|
560
|
-
@
|
|
562
|
+
@_primexpr
|
|
561
563
|
def get_batch_params_reduce_axes(begin_params_axis, x_shape):
|
|
562
564
|
if begin_params_axis < 0:
|
|
563
565
|
x_rank = len(x_shape)
|
|
@@ -565,7 +567,7 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
|
|
|
565
567
|
batch_params_reduce_axes = tuple(range(1, begin_params_axis))
|
|
566
568
|
return batch_params_reduce_axes
|
|
567
569
|
|
|
568
|
-
@
|
|
570
|
+
@_primexpr
|
|
569
571
|
def get_logical_shape(var_shape):
|
|
570
572
|
return var_shape[1:]
|
|
571
573
|
|
|
@@ -16,10 +16,12 @@
|
|
|
16
16
|
"""image_ops vmap impl."""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
|
|
19
|
-
import
|
|
19
|
+
import numpy as np
|
|
20
|
+
from mindspore import Tensor
|
|
20
21
|
from mindspore.ops import functional as F
|
|
21
22
|
from mindspore.ops.operations import _grad_ops as G
|
|
22
23
|
from mindspore.ops.operations import image_ops as IMG
|
|
24
|
+
from mindspore.ops import constexpr
|
|
23
25
|
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
|
|
24
26
|
_raise_value_error
|
|
25
27
|
|
|
@@ -90,6 +92,13 @@ def get_resize_grad_dynamic_rule(prim, axis_size):
|
|
|
90
92
|
def get_crop_and_resize_vmap_rule(prim, axis_size):
|
|
91
93
|
"""VmapRule for `CropAndResize` operation."""
|
|
92
94
|
|
|
95
|
+
@constexpr
|
|
96
|
+
def get_box_indices_offsets(axis_size, batch_size, num_boxes):
|
|
97
|
+
offsets = np.arange(0, axis_size * batch_size, batch_size).astype(np.int32)
|
|
98
|
+
offsets = np.reshape(offsets, (axis_size, 1))
|
|
99
|
+
offsets = np.broadcast_to(offsets, (axis_size, num_boxes))
|
|
100
|
+
return Tensor(offsets)
|
|
101
|
+
|
|
93
102
|
def vmap_rule(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim):
|
|
94
103
|
is_all_none, result = vmap_general_preprocess(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim)
|
|
95
104
|
if is_all_none:
|
|
@@ -115,10 +124,8 @@ def get_crop_and_resize_vmap_rule(prim, axis_size):
|
|
|
115
124
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
116
125
|
x_shape = F.shape(x)
|
|
117
126
|
x = F.reshape(x, (-1,) + x_shape[2:])
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
counts = F.broadcast_to(counts, (axis_size, num_boxes))
|
|
121
|
-
box_indices = F.add(box_indices, counts)
|
|
127
|
+
offsets = get_box_indices_offsets(axis_size, x_shape[1], num_boxes)
|
|
128
|
+
box_indices = F.add(box_indices, offsets)
|
|
122
129
|
box_indices = F.reshape(box_indices, (-1,))
|
|
123
130
|
out = prim(x, boxes, box_indices, crop_size)
|
|
124
131
|
|