mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/rnn_cells.py
CHANGED
|
@@ -24,9 +24,9 @@ from mindspore import log as logger
|
|
|
24
24
|
from mindspore.common.tensor import Tensor
|
|
25
25
|
from mindspore.common.parameter import Parameter
|
|
26
26
|
from mindspore.common.initializer import initializer, Uniform
|
|
27
|
-
from mindspore.ops.primitive import constexpr
|
|
27
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
28
28
|
from mindspore.nn.cell import Cell
|
|
29
|
-
from mindspore
|
|
29
|
+
from mindspore import _checkparam as validator
|
|
30
30
|
|
|
31
31
|
__all__ = ['LSTMCell', 'GRUCell', 'RNNCell']
|
|
32
32
|
|
|
@@ -60,7 +60,7 @@ def _check_tuple_length(param_name, input_data, length, cls_name):
|
|
|
60
60
|
f"but got '{len(input_data)}'")
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
@
|
|
63
|
+
@_primexpr
|
|
64
64
|
def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
|
|
65
65
|
if batch_size_x != batch_size_hx:
|
|
66
66
|
raise ValueError(f"For '{cls_name}' batch size of x and hx must be equal, but got {batch_size_x} of x "
|
|
@@ -175,7 +175,7 @@ class RNNCell(RNNCellBase):
|
|
|
175
175
|
|
|
176
176
|
Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
|
|
177
177
|
the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
|
|
178
|
-
previous layer at time
|
|
178
|
+
previous layer at time :math:`t-1` or the initial hidden state at time `0`.
|
|
179
179
|
If `nonlinearity` is `relu`, then `relu` is used instead of `tanh`.
|
|
180
180
|
|
|
181
181
|
Args:
|
|
@@ -266,12 +266,12 @@ class LSTMCell(RNNCellBase):
|
|
|
266
266
|
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
|
267
267
|
|
|
268
268
|
Inputs:
|
|
269
|
-
- **x** (Tensor) - Tensor of shape (
|
|
269
|
+
- **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)`.
|
|
270
270
|
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
|
|
271
|
-
and shape (
|
|
271
|
+
and shape :math:`(batch\_size, hidden\_size)`. The data type of `hx` must be the same as `x`.
|
|
272
272
|
|
|
273
273
|
Outputs:
|
|
274
|
-
- **hx'** (Tensor) - A tuple of two Tensors (h', c') both of data shape (
|
|
274
|
+
- **hx'** (Tensor) - A tuple of two Tensors (h', c') both of data shape :math:`(batch\_size, hidden\_size)`.
|
|
275
275
|
|
|
276
276
|
Raises:
|
|
277
277
|
TypeError: If `input_size`, `hidden_size` is not an int.
|
|
@@ -340,23 +340,18 @@ class GRUCell(RNNCellBase):
|
|
|
340
340
|
`Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation
|
|
341
341
|
<https://aclanthology.org/D14-1179.pdf>`_.
|
|
342
342
|
|
|
343
|
-
The LSTMCell can be simplified in NN layer, the following formula:
|
|
344
|
-
|
|
345
|
-
.. math::
|
|
346
|
-
h^{'},c^{'} = LSTMCell(x, (h_0, c_0))
|
|
347
|
-
|
|
348
343
|
Args:
|
|
349
344
|
input_size (int): Number of features of input.
|
|
350
345
|
hidden_size (int): Number of features of hidden layer.
|
|
351
346
|
has_bias (bool): Whether the cell has bias `b_in` and `b_hn`. Default: True.
|
|
352
347
|
|
|
353
348
|
Inputs:
|
|
354
|
-
- **x** (Tensor) - Tensor of shape (
|
|
355
|
-
- **hx** (Tensor) - Tensor of data type mindspore.float32 and shape (
|
|
349
|
+
- **x** (Tensor) - Tensor of shape :math:`(batch\_size, input\_size)`.
|
|
350
|
+
- **hx** (Tensor) - Tensor of data type mindspore.float32 and shape :math:`(batch\_size, hidden\_size)`.
|
|
356
351
|
Data type of `hx` must be the same as `x`.
|
|
357
352
|
|
|
358
353
|
Outputs:
|
|
359
|
-
- **hx'** (Tensor) - Tensor of shape (
|
|
354
|
+
- **hx'** (Tensor) - Tensor of shape :math:`(batch\_size, hidden\_size)`.
|
|
360
355
|
|
|
361
356
|
Raises:
|
|
362
357
|
TypeError: If `input_size`, `hidden_size` is not an int.
|
mindspore/nn/layer/rnns.py
CHANGED
|
@@ -22,28 +22,23 @@ import mindspore.nn as nn
|
|
|
22
22
|
import mindspore.ops as P
|
|
23
23
|
import mindspore.context as context
|
|
24
24
|
import mindspore.common.dtype as mstype
|
|
25
|
-
from mindspore.ops
|
|
25
|
+
from mindspore.ops import functional as F
|
|
26
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
26
27
|
from mindspore.common.tensor import Tensor
|
|
27
28
|
from mindspore.common.parameter import ParameterTuple, Parameter
|
|
28
29
|
from mindspore.nn.cell import Cell
|
|
29
30
|
from mindspore import log as logger
|
|
30
|
-
from mindspore
|
|
31
|
+
from mindspore import _checkparam as validator
|
|
31
32
|
from mindspore.ops.operations._rl_inner_ops import CudnnGRU
|
|
32
33
|
from mindspore.nn.layer.rnn_cells import _rnn_relu_cell, _rnn_tanh_cell, _gru_cell, _lstm_cell
|
|
33
|
-
from mindspore.nn.layer.rnn_utils import _Reverse, _ReverseSequence
|
|
34
34
|
|
|
35
35
|
__all__ = ['LSTM', 'GRU', 'RNN']
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
@
|
|
39
|
-
def arange(start, stop, step, dtype):
|
|
40
|
-
return Tensor(np.arange(start, stop, step), dtype)
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
@constexpr
|
|
38
|
+
@_primexpr
|
|
44
39
|
def _init_state(shape, dtype, is_lstm):
|
|
45
|
-
hx =
|
|
46
|
-
cx =
|
|
40
|
+
hx = P.Zeros()(shape, dtype)
|
|
41
|
+
cx = P.Zeros()(shape, dtype)
|
|
47
42
|
if is_lstm:
|
|
48
43
|
return (hx, cx)
|
|
49
44
|
return hx
|
|
@@ -84,7 +79,7 @@ def _check_tuple_length(param_name, input_data, length, cls_name):
|
|
|
84
79
|
f"but got '{len(input_data)}'")
|
|
85
80
|
|
|
86
81
|
|
|
87
|
-
@
|
|
82
|
+
@_primexpr
|
|
88
83
|
def _check_seq_length_size(batch_size_x, seq_length_size, cls_name):
|
|
89
84
|
if batch_size_x != seq_length_size:
|
|
90
85
|
raise ValueError(f"For '{cls_name}' batch size of x and seq_length must be equal, "
|
|
@@ -93,7 +88,7 @@ def _check_seq_length_size(batch_size_x, seq_length_size, cls_name):
|
|
|
93
88
|
|
|
94
89
|
def sequence_mask(lengths, maxlen):
|
|
95
90
|
"""generate mask matrix by seq_length"""
|
|
96
|
-
range_vector = arange(0, maxlen, 1, lengths.dtype)
|
|
91
|
+
range_vector = P.arange(start=0, end=maxlen, step=1, dtype=lengths.dtype)
|
|
97
92
|
result = range_vector < lengths.view(lengths.shape + (1,))
|
|
98
93
|
return result.astype(mstype.int32)
|
|
99
94
|
|
|
@@ -106,7 +101,7 @@ def select_by_mask(inputs, mask):
|
|
|
106
101
|
|
|
107
102
|
def get_hidden(output, seq_length):
|
|
108
103
|
"""get hidden state by seq_length"""
|
|
109
|
-
batch_index = arange(0, seq_length.shape[0], 1, seq_length.dtype)
|
|
104
|
+
batch_index = P.arange(start=0, end=seq_length.shape[0], step=1, dtype=seq_length.dtype)
|
|
110
105
|
indices = P.Concat(1)((seq_length.view(-1, 1) - 1, batch_index.view(-1, 1)))
|
|
111
106
|
return P.GatherNd()(output, indices)
|
|
112
107
|
|
|
@@ -158,7 +153,7 @@ class _DynamicRNNBase(Cell):
|
|
|
158
153
|
hidden_size = h.shape[-1]
|
|
159
154
|
zero_output = P.ZerosLike()(h_t)
|
|
160
155
|
seq_length = P.Cast()(seq_length, mstype.float32)
|
|
161
|
-
seq_length =
|
|
156
|
+
seq_length = F.broadcast_to(seq_length, (hidden_size, -1))
|
|
162
157
|
seq_length = P.Cast()(seq_length, mstype.int32)
|
|
163
158
|
seq_length = P.Transpose()(seq_length, (1, 0))
|
|
164
159
|
|
|
@@ -220,6 +215,7 @@ class _DynamicGRUCPUGPU(Cell):
|
|
|
220
215
|
self.is_gpu = context.get_context("device_target") == "GPU"
|
|
221
216
|
|
|
222
217
|
def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
|
|
218
|
+
'''_DynamicGRUCPUGPU'''
|
|
223
219
|
gate_size, input_size = w_ih.shape
|
|
224
220
|
hidden_size = gate_size // 3
|
|
225
221
|
if self.is_gpu:
|
|
@@ -262,15 +258,16 @@ class _DynamicGRUAscend(Cell):
|
|
|
262
258
|
self.dtype = mstype.float16
|
|
263
259
|
|
|
264
260
|
def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
|
|
261
|
+
'''Dynamic GRU module on Ascend'''
|
|
265
262
|
if b_ih is None:
|
|
266
263
|
b_ih = P.Zeros()(w_ih.shape[0], w_ih.dtype)
|
|
267
264
|
b_hh = P.Zeros()(w_ih.shape[0], w_ih.dtype)
|
|
268
265
|
outputs, _, _, _, _, _ = self.gru(self.cast(x, self.dtype), \
|
|
269
266
|
self.cast(self.transpose(w_ih, (1, 0)), self.dtype), \
|
|
270
267
|
self.cast(self.transpose(w_hh, (1, 0)), self.dtype), \
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
None,
|
|
268
|
+
b_ih, \
|
|
269
|
+
b_hh, \
|
|
270
|
+
None, h_0)
|
|
274
271
|
if seq_length is not None:
|
|
275
272
|
h = get_hidden(outputs, seq_length)
|
|
276
273
|
mask = sequence_mask(seq_length, x.shape[0])
|
|
@@ -289,6 +286,7 @@ class _DynamicLSTMCPUGPU(Cell):
|
|
|
289
286
|
self.is_gpu = context.get_context("device_target") == "GPU"
|
|
290
287
|
|
|
291
288
|
def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
|
|
289
|
+
'''Dynamic LSTM module on CPU and GPU'''
|
|
292
290
|
gate_size, input_size = w_ih.shape
|
|
293
291
|
hidden_size = gate_size // 4
|
|
294
292
|
if seq_length is not None:
|
|
@@ -339,6 +337,7 @@ class _DynamicLSTMAscend(Cell):
|
|
|
339
337
|
self.dtype = mstype.float16
|
|
340
338
|
|
|
341
339
|
def construct(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
|
|
340
|
+
'''Dynamic LSTM module on Ascend'''
|
|
342
341
|
w_ih_i, w_ih_f, w_ih_g, w_ih_o = self.split(w_ih)
|
|
343
342
|
w_hh_i, w_hh_f, w_hh_g, w_hh_o = self.split(w_hh)
|
|
344
343
|
w_ih = self.concat_dim0((w_ih_i, w_ih_g, w_ih_f, w_ih_o))
|
|
@@ -414,17 +413,13 @@ class _RNNBase(Cell):
|
|
|
414
413
|
raise ValueError(f"For '{self.cls_name}', the 'mode' must be in ['RNN_RELU', 'RNN_TANH', 'LSTM', 'GRU'], "
|
|
415
414
|
f"but got {mode}.")
|
|
416
415
|
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
self.reverse_sequence = _ReverseSequence(0, 1)
|
|
420
|
-
else:
|
|
421
|
-
self.reverse = P.ReverseV2([0])
|
|
422
|
-
self.reverse_sequence = P.ReverseSequence(0, 1)
|
|
416
|
+
self.reverse = P.ReverseV2([0])
|
|
417
|
+
self.reverse_sequence = P.ReverseSequence(0, 1)
|
|
423
418
|
self.hidden_size = hidden_size
|
|
424
419
|
self.batch_first = batch_first
|
|
425
420
|
self.num_layers = num_layers
|
|
426
421
|
self.dropout = dropout
|
|
427
|
-
self.dropout_op = nn.Dropout(float(
|
|
422
|
+
self.dropout_op = nn.Dropout(p=float(dropout))
|
|
428
423
|
self.bidirectional = bidirectional
|
|
429
424
|
self.has_bias = has_bias
|
|
430
425
|
num_directions = 2 if bidirectional else 1
|
|
@@ -503,9 +498,11 @@ class _RNNBase(Cell):
|
|
|
503
498
|
if self.is_lstm:
|
|
504
499
|
h_n = P.Concat(0)(h_n)
|
|
505
500
|
c_n = P.Concat(0)(c_n)
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
501
|
+
h0_shape = h[0].shape
|
|
502
|
+
h1_shape = h[1].shape
|
|
503
|
+
h_n = h_n.view(h0_shape)
|
|
504
|
+
c_n = c_n.view(h1_shape)
|
|
505
|
+
return output, (h_n.view(h0_shape), c_n.view(h1_shape))
|
|
509
506
|
h_n = P.Concat(0)(h_n)
|
|
510
507
|
return output, h_n.view(h.shape)
|
|
511
508
|
|
|
@@ -535,9 +532,11 @@ class _RNNBase(Cell):
|
|
|
535
532
|
if self.is_lstm:
|
|
536
533
|
h_n = P.Concat(0)(h_n)
|
|
537
534
|
c_n = P.Concat(0)(c_n)
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
535
|
+
h0_shape = h[0].shape
|
|
536
|
+
h1_shape = h[1].shape
|
|
537
|
+
h_n = h_n.view(h0_shape)
|
|
538
|
+
c_n = c_n.view(h1_shape)
|
|
539
|
+
return output, (h_n.view(h0_shape), c_n.view(h1_shape))
|
|
541
540
|
h_n = P.Concat(0)(h_n)
|
|
542
541
|
return output, h_n.view(h.shape)
|
|
543
542
|
|
|
@@ -591,7 +590,7 @@ class RNN(_RNNBase):
|
|
|
591
590
|
|
|
592
591
|
Here :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
|
|
593
592
|
the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
|
|
594
|
-
previous layer at time
|
|
593
|
+
previous layer at time :math:`t-1` or the initial hidden state at time `0`.
|
|
595
594
|
If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.
|
|
596
595
|
|
|
597
596
|
Args:
|
|
@@ -671,7 +670,7 @@ class GRU(_RNNBase):
|
|
|
671
670
|
Given an input :math:`x_t` at time :math:`t`, a hidden state :math:`h_{t-1}`, the update and reset gate at
|
|
672
671
|
time :math:`t` is computed using a gating mechanism. Update gate :math:`z_t` is designed to protect the cell
|
|
673
672
|
from perturbation by irrelevant inputs and past hidden state. Reset gate :math:`r_t` determines how much
|
|
674
|
-
information should be reset from old hidden state. New memory state :math:`
|
|
673
|
+
information should be reset from old hidden state. New memory state :math:`n_t` is
|
|
675
674
|
calculated with the current input, on which the reset gate will be applied. Finally, current hidden state
|
|
676
675
|
:math:`h_{t}` is computed with the calculated update grate and new memory state. The complete
|
|
677
676
|
formulation is as follows:
|
|
@@ -805,12 +804,12 @@ class LSTM(_RNNBase):
|
|
|
805
804
|
|
|
806
805
|
Inputs:
|
|
807
806
|
- **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
|
|
808
|
-
shape (
|
|
807
|
+
shape :math:`(seq\_len, batch\_size, input\_size)` or :math:`(batch\_size, seq\_len, input\_size)`.
|
|
809
808
|
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
|
|
810
|
-
or mindspore.float16 and shape (
|
|
809
|
+
or mindspore.float16 and shape :math:`(num\_directions * num\_layers, batch\_size, hidden\_size)`.
|
|
811
810
|
The data type of `hx` must be the same as `x`.
|
|
812
811
|
- **seq_length** (Tensor) - The length of each sequence in an input batch.
|
|
813
|
-
Tensor of shape :math:`(\
|
|
812
|
+
Tensor of shape :math:`(batch\_size)`. Default: None.
|
|
814
813
|
This input indicates the real sequence length before padding to avoid padded elements
|
|
815
814
|
have been used to compute hidden state and affect the final output. It is recommended to
|
|
816
815
|
use this input when **x** has padding elements.
|
|
@@ -818,9 +817,9 @@ class LSTM(_RNNBase):
|
|
|
818
817
|
Outputs:
|
|
819
818
|
Tuple, a tuple contains (`output`, (`h_n`, `c_n`)).
|
|
820
819
|
|
|
821
|
-
- **output** (Tensor) - Tensor of shape (
|
|
820
|
+
- **output** (Tensor) - Tensor of shape :math:`(seq\_len, batch\_size, num\_directions * hidden\_size)` .
|
|
822
821
|
- **hx_n** (tuple) - A tuple of two Tensor (h_n, c_n) both of shape
|
|
823
|
-
(
|
|
822
|
+
:math:`(num\_directions * num\_layers, batch\_size, hidden\_size)` .
|
|
824
823
|
|
|
825
824
|
Raises:
|
|
826
825
|
TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
|
mindspore/nn/layer/thor_layer.py
CHANGED
|
@@ -25,7 +25,8 @@ from mindspore.communication.management import get_group_size, get_rank
|
|
|
25
25
|
from mindspore.ops import operations as P
|
|
26
26
|
from mindspore.ops.operations._thor_ops import ThorIm2Col
|
|
27
27
|
from mindspore.common.parameter import Parameter
|
|
28
|
-
from mindspore
|
|
28
|
+
from mindspore import _checkparam as Validator
|
|
29
|
+
from mindspore._checkparam import twice
|
|
29
30
|
from mindspore import context
|
|
30
31
|
from mindspore.nn.cell import Cell
|
|
31
32
|
from mindspore.nn.layer.activation import get_activation
|
|
@@ -33,9 +34,9 @@ from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context, \
|
|
|
33
34
|
_set_rank_id, _insert_hash_table_size, _set_cache_enable
|
|
34
35
|
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
|
|
35
36
|
from mindspore.context import ParallelMode
|
|
36
|
-
from mindspore.ops.primitive import constexpr
|
|
37
37
|
from mindspore.ops import functional as F
|
|
38
38
|
from mindspore.nn.layer.basic import ClipByNorm
|
|
39
|
+
from mindspore.ops.primitive import constexpr
|
|
39
40
|
|
|
40
41
|
__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor', 'EmbeddingLookupThor']
|
|
41
42
|
|
|
@@ -45,8 +46,8 @@ class DenseThor(Cell):
|
|
|
45
46
|
The dense connected layer and saving the information needed for THOR.
|
|
46
47
|
|
|
47
48
|
Applies dense connected layer for the input and saves the information A and G in the dense connected layer
|
|
48
|
-
needed for THOR
|
|
49
|
-
|
|
49
|
+
needed for THOR.
|
|
50
|
+
|
|
50
51
|
This layer implements the operation as:
|
|
51
52
|
|
|
52
53
|
.. math::
|
|
@@ -283,7 +284,6 @@ class Conv2dThor(_ConvThor):
|
|
|
283
284
|
Applies a 2D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, H_{in}, W_{in})`,
|
|
284
285
|
where :math:`N` is batch size, :math:`C_{in}` is channel number, and :math:`H_{in}, W_{in})` are height and width.
|
|
285
286
|
And saves the information A and G in the 2D convolution layer needed for THOR.
|
|
286
|
-
The detail can be seen in paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf
|
|
287
287
|
|
|
288
288
|
For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
|
|
289
289
|
|
|
@@ -434,8 +434,8 @@ class Conv2dThor(_ConvThor):
|
|
|
434
434
|
"""Initialize depthwise conv2d op"""
|
|
435
435
|
if context.get_context("device_target") == "Ascend" and self.group > 1:
|
|
436
436
|
self.dilation = self._dilation
|
|
437
|
-
Validator.check_int('group', self.group, self.in_channels,
|
|
438
|
-
Validator.check_int('group', self.group, self.out_channels,
|
|
437
|
+
Validator.check_int('group', self.group, self.in_channels, Validator.EQ, self.cls_name)
|
|
438
|
+
Validator.check_int('group', self.group, self.out_channels, Validator.EQ, self.cls_name)
|
|
439
439
|
self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
|
|
440
440
|
kernel_size=self.kernel_size,
|
|
441
441
|
pad_mode=self.pad_mode,
|
|
@@ -540,7 +540,7 @@ class EmbeddingThor(Cell):
|
|
|
540
540
|
This module is often used to store word embeddings and retrieve them using
|
|
541
541
|
indices. The input to the module is a list of indices, and the output is
|
|
542
542
|
the corresponding word embeddings. And saves the information A and G in the dense connected layer
|
|
543
|
-
needed for THOR
|
|
543
|
+
needed for THOR.
|
|
544
544
|
|
|
545
545
|
Note:
|
|
546
546
|
When 'use_one_hot' is set to True, the type of the input `x` must be mindspore.int32.
|
|
@@ -588,9 +588,9 @@ class EmbeddingThor(Cell):
|
|
|
588
588
|
self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
|
|
589
589
|
self.padding_idx = padding_idx
|
|
590
590
|
if padding_idx is not None:
|
|
591
|
-
self.padding_idx = Validator.check_int_range(padding_idx, 0, vocab_size,
|
|
591
|
+
self.padding_idx = Validator.check_int_range(padding_idx, 0, vocab_size, Validator.INC_BOTH,
|
|
592
592
|
"padding_idx", self.cls_name)
|
|
593
|
-
self.init_tensor = self.init_tensor.
|
|
593
|
+
self.init_tensor = self.init_tensor.init_data().asnumpy()
|
|
594
594
|
self.init_tensor[self.padding_idx] = 0
|
|
595
595
|
self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
|
|
596
596
|
self.expand = P.ExpandDims()
|
|
@@ -671,8 +671,7 @@ class EmbeddingLookupThor(Cell):
|
|
|
671
671
|
and saving the information needed for THOR.
|
|
672
672
|
|
|
673
673
|
This module has the same function as EmbeddingLookup, but additionally saves the information A and G in the
|
|
674
|
-
embeddinglookup layer needed for THOR
|
|
675
|
-
the detail can be seen in paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf
|
|
674
|
+
embeddinglookup layer needed for THOR.
|
|
676
675
|
|
|
677
676
|
|
|
678
677
|
Args:
|
|
@@ -15,16 +15,16 @@
|
|
|
15
15
|
"""Time Distributed."""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
|
-
from mindspore.ops.primitive import constexpr, Primitive
|
|
18
|
+
from mindspore.ops.primitive import constexpr, Primitive, _primexpr
|
|
19
19
|
from mindspore.ops import Reshape, Transpose, Stack, Unstack
|
|
20
20
|
from mindspore.common import Tensor
|
|
21
|
-
from mindspore
|
|
21
|
+
from mindspore import _checkparam as Validator
|
|
22
22
|
from mindspore.nn.cell import Cell
|
|
23
23
|
|
|
24
24
|
__all__ = ['TimeDistributed']
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
@
|
|
27
|
+
@_primexpr
|
|
28
28
|
def _check_reshape_pos(reshape_pos, inputs_shape, outputs_shape, prim_name=None):
|
|
29
29
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
30
30
|
if reshape_pos >= len(outputs_shape) or inputs_shape[reshape_pos] != outputs_shape[reshape_pos]:
|
|
@@ -35,7 +35,7 @@ def _check_reshape_pos(reshape_pos, inputs_shape, outputs_shape, prim_name=None)
|
|
|
35
35
|
f"{outputs_shape}. You may try pass parameters without 'reshape_with_axis'.")
|
|
36
36
|
|
|
37
37
|
|
|
38
|
-
@
|
|
38
|
+
@_primexpr
|
|
39
39
|
def _check_expand_dims_axis(time_axis, ndim, prim_name=None):
|
|
40
40
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
41
41
|
if time_axis > ndim:
|
|
@@ -57,7 +57,7 @@ def _check_data(flag, prim_name=None):
|
|
|
57
57
|
raise TypeError(f"{msg_prefix} inputs and outputs must be a Tensor.")
|
|
58
58
|
|
|
59
59
|
|
|
60
|
-
@
|
|
60
|
+
@_primexpr
|
|
61
61
|
def _check_inputs_dim(shape, prim_name=None):
|
|
62
62
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
63
63
|
if len(shape) < 3:
|