mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
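Several entries in the list above are API relocations rather than content changes: the public `mindspore.nn.transformer` package moves to the private `mindspore.parallel._transformer`, the thin `mindspore.parallel.nn` re-export layer is deleted, and the `mindspore.compression` package is removed outright. A minimal, hypothetical compatibility probe for the relocation is sketched below; the try/except pattern and the alias are illustrative assumptions, not part of this release, and since the new location is a private module, pinning a release that still ships `mindspore.nn.transformer` is the safer long-term fix.

```python
# Hypothetical probe for the nn/transformer -> parallel/_transformer move
# shown in the manifest above. Assumes mindspore is installed; the rc1
# location is private API, so use at your own risk.
try:
    from mindspore.parallel import _transformer as transformer  # 2.0.0rc1 layout
except ImportError:
    from mindspore.nn import transformer  # 2.0.0a0 layout
```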
|
@@ -18,13 +18,21 @@ from __future__ import absolute_import
|
|
|
18
18
|
from functools import partial
|
|
19
19
|
|
|
20
20
|
import mindspore.context as context
|
|
21
|
-
from mindspore
|
|
22
|
-
from mindspore._checkparam import Rel
|
|
21
|
+
from mindspore import _checkparam as validator
|
|
23
22
|
from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
|
24
23
|
from mindspore.common import dtype as mstype
|
|
25
24
|
from mindspore.common.dtype import QuantDtype
|
|
26
25
|
|
|
27
|
-
|
|
26
|
+
|
|
27
|
+
def _support_te():
|
|
28
|
+
try:
|
|
29
|
+
import te # pylint: disable=unused-import
|
|
30
|
+
return True
|
|
31
|
+
# pylint: disable=broad-except
|
|
32
|
+
except Exception:
|
|
33
|
+
return False
|
|
34
|
+
|
|
35
|
+
if context.get_context('device_target') == "Ascend" and _support_te():
|
|
28
36
|
import mindspore.ops._op_impl._custom_op
|
|
29
37
|
|
|
30
38
|
__all__ = ["MinMaxUpdatePerLayer",
|
|
@@ -108,8 +116,22 @@ class FakeQuantParam(Primitive):
|
|
|
108
116
|
|
|
109
117
|
@classmethod
|
|
110
118
|
def linear_quant_param(cls, quant_dtype, scale, zp, is_per_channel=False, **kwargs):
|
|
111
|
-
|
|
112
|
-
|
|
119
|
+
"""
|
|
120
|
+
Create a linear quantization operator based on scale and zero-point parameter.
|
|
121
|
+
"""
|
|
122
|
+
validator.check_value_type("scale", scale, [float, tuple, list], "FakeQuantParam")
|
|
123
|
+
if isinstance(scale, float):
|
|
124
|
+
scale_list = [scale]
|
|
125
|
+
else:
|
|
126
|
+
scale_list = scale
|
|
127
|
+
validator.check_value_type("zero_point", zp, [int, tuple, list], "FakeQuantParam")
|
|
128
|
+
if isinstance(zp, int):
|
|
129
|
+
zp_list = [zp]
|
|
130
|
+
else:
|
|
131
|
+
zp_list = zp
|
|
132
|
+
validator.check_value_type("is_per_channel", is_per_channel, [bool], "FakeQuantParam")
|
|
133
|
+
kwargs[FakeQuantParam.attr_key_linear_quant_scale] = scale_list
|
|
134
|
+
kwargs[FakeQuantParam.attr_key_linear_quant_zero_point] = zp_list
|
|
113
135
|
return cls(quant_dtype, FakeQuantParam.attr_value_linear_quant_algo_name, is_per_channel, **kwargs)
|
|
114
136
|
|
|
115
137
|
|
|
@@ -147,14 +169,14 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
|
|
|
147
169
|
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
|
|
148
170
|
|
|
149
171
|
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
|
|
150
|
-
self.ema_decay = validator.check_float_range(ema_decay, 0, 1,
|
|
172
|
+
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
|
|
151
173
|
self.init_prim_io_names(inputs=['x', 'min', 'max'],
|
|
152
174
|
outputs=['min_up', 'max_up'])
|
|
153
175
|
|
|
154
176
|
def infer_shape(self, x_shape, min_shape, max_shape):
|
|
155
|
-
validator.check_int(len(x_shape), 1,
|
|
177
|
+
validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
|
|
156
178
|
validator.check("min shape", min_shape, "max shape",
|
|
157
|
-
max_shape,
|
|
179
|
+
max_shape, validator.EQ, self.name)
|
|
158
180
|
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
|
|
159
181
|
return min_shape, max_shape
|
|
160
182
|
|
|
@@ -203,9 +225,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
|
|
|
203
225
|
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
|
|
204
226
|
|
|
205
227
|
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
|
|
206
|
-
self.ema_decay = validator.check_float_range(ema_decay, 0, 1,
|
|
228
|
+
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
|
|
207
229
|
if self.is_ascend:
|
|
208
|
-
self.channel_axis = validator.check_int_range(channel_axis, 0, 1,
|
|
230
|
+
self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
|
|
231
|
+
'channel_axis', self.name)
|
|
209
232
|
else:
|
|
210
233
|
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
|
|
211
234
|
self.init_prim_io_names(
|
|
@@ -215,9 +238,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
|
|
|
215
238
|
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
|
|
216
239
|
raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
|
|
217
240
|
if not self.is_ascend:
|
|
218
|
-
validator.check_int(len(x_shape), 1,
|
|
241
|
+
validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
|
|
219
242
|
validator.check("min shape", min_shape, "max shape",
|
|
220
|
-
max_shape,
|
|
243
|
+
max_shape, validator.EQ, self.name)
|
|
221
244
|
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
|
|
222
245
|
return min_shape, max_shape
|
|
223
246
|
|
|
@@ -273,9 +296,9 @@ class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer):
|
|
|
273
296
|
outputs=['out'])
|
|
274
297
|
|
|
275
298
|
def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
|
|
276
|
-
validator.check_int(len(input_x_shape), 1,
|
|
277
|
-
validator.check_int(len(alpha_shape), 1,
|
|
278
|
-
validator.check_int(len(quant_max_shape), 1,
|
|
299
|
+
validator.check_int(len(input_x_shape), 1, validator.GE, "input_x rank", self.name)
|
|
300
|
+
validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
|
|
301
|
+
validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
|
|
279
302
|
return input_x_shape
|
|
280
303
|
|
|
281
304
|
def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
|
|
@@ -314,9 +337,9 @@ class FakeLearnedScaleQuantPerLayerGrad(PrimitiveWithInfer):
|
|
|
314
337
|
inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
|
|
315
338
|
|
|
316
339
|
def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
|
|
317
|
-
validator.check("dout shape", dout_shape, "x_shape", x_shape,
|
|
318
|
-
validator.check_int(len(alpha_shape), 1,
|
|
319
|
-
validator.check_int(len(quant_max_shape), 1,
|
|
340
|
+
validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
|
|
341
|
+
validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
|
|
342
|
+
validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
|
|
320
343
|
return dout_shape, alpha_shape
|
|
321
344
|
|
|
322
345
|
def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
|
|
@@ -345,9 +368,9 @@ class FakeLearnedScaleQuantPerLayerGradD(PrimitiveWithInfer):
|
|
|
345
368
|
inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
|
|
346
369
|
|
|
347
370
|
def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
|
|
348
|
-
validator.check("dout shape", dout_shape, "x_shape", x_shape,
|
|
349
|
-
validator.check_int(len(alpha_shape), 1,
|
|
350
|
-
validator.check_int(len(quant_max_shape), 1,
|
|
371
|
+
validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
|
|
372
|
+
validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
|
|
373
|
+
validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
|
|
351
374
|
return dout_shape, dout_shape
|
|
352
375
|
|
|
353
376
|
def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
|
|
@@ -423,7 +446,8 @@ class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
|
|
|
423
446
|
self.training = validator.check_value_type(
|
|
424
447
|
'training', training, (bool,), self.name)
|
|
425
448
|
if self.is_ascend:
|
|
426
|
-
self.channel_axis = validator.check_int_range(channel_axis, 0, 1,
|
|
449
|
+
self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
|
|
450
|
+
'channel_axis', self.name)
|
|
427
451
|
else:
|
|
428
452
|
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
|
|
429
453
|
self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
|
|
@@ -433,12 +457,12 @@ class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
|
|
|
433
457
|
if self.is_ascend and len(input_x_shape) not in self.ascend_support_x_rank:
|
|
434
458
|
raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
|
|
435
459
|
if not self.is_ascend:
|
|
436
|
-
validator.check_int(len(input_x_shape), 1,
|
|
460
|
+
validator.check_int(len(input_x_shape), 1, validator.GE, "input_x rank", self.name)
|
|
437
461
|
if len(input_x_shape) == 1:
|
|
438
462
|
self.channel_axis = 0
|
|
439
463
|
|
|
440
464
|
validator.check_equal_int(alpha_shape[0], input_x_shape[self.channel_axis], "alpha rank", self.name)
|
|
441
|
-
validator.check_int(len(quant_max_shape), 1,
|
|
465
|
+
validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
|
|
442
466
|
return input_x_shape
|
|
443
467
|
|
|
444
468
|
def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
|
|
@@ -479,7 +503,7 @@ class FakeLearnedScaleQuantPerChannelGrad(PrimitiveWithInfer):
|
|
|
479
503
|
inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
|
|
480
504
|
|
|
481
505
|
def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
|
|
482
|
-
validator.check("dout shape", dout_shape, "x_shape", x_shape,
|
|
506
|
+
validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
|
|
483
507
|
return dout_shape, alpha_shape
|
|
484
508
|
|
|
485
509
|
def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
|
|
@@ -510,9 +534,9 @@ class FakeLearnedScaleQuantPerChannelGradD(PrimitiveWithInfer):
|
|
|
510
534
|
inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
|
|
511
535
|
|
|
512
536
|
def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
|
|
513
|
-
validator.check("dout shape", dout_shape, "x_shape", x_shape,
|
|
514
|
-
validator.check_int(len(alpha_shape), 1,
|
|
515
|
-
validator.check_int(len(quant_max_shape), 1,
|
|
537
|
+
validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
|
|
538
|
+
validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
|
|
539
|
+
validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
|
|
516
540
|
return dout_shape, dout_shape
|
|
517
541
|
|
|
518
542
|
def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
|
|
@@ -576,7 +600,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
|
|
|
576
600
|
num_bits=8,
|
|
577
601
|
narrow_range=False):
|
|
578
602
|
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
|
|
579
|
-
self.num_bits = validator.check_int_range(self.num_bits, 2, 16,
|
|
603
|
+
self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
|
|
580
604
|
self.narrow_range = validator.check_value_type(
|
|
581
605
|
'narrow_range', narrow_range, (bool,), self.name)
|
|
582
606
|
|
|
@@ -588,9 +612,9 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
|
|
|
588
612
|
raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
|
|
589
613
|
|
|
590
614
|
def infer_shape(self, x_shape, min_shape, max_shape):
|
|
591
|
-
validator.check_int(len(x_shape), 1,
|
|
592
|
-
validator.check("min shape", min_shape, "max shape", max_shape,
|
|
593
|
-
validator.check_int(len(min_shape), 1,
|
|
615
|
+
validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
|
|
616
|
+
validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
|
|
617
|
+
validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
|
|
594
618
|
self.check_broadcast(min_shape, x_shape)
|
|
595
619
|
return x_shape
|
|
596
620
|
|
|
@@ -640,7 +664,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
|
|
|
640
664
|
num_bits=8,
|
|
641
665
|
narrow_range=False):
|
|
642
666
|
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
|
|
643
|
-
self.num_bits = validator.check_int_range(self.num_bits, 2, 16,
|
|
667
|
+
self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
|
|
644
668
|
self.narrow_range = validator.check_value_type(
|
|
645
669
|
'narrow_range', narrow_range, (bool,), self.name)
|
|
646
670
|
|
|
@@ -652,10 +676,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
|
|
|
652
676
|
raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
|
|
653
677
|
|
|
654
678
|
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
|
|
655
|
-
validator.check_int(len(x_shape), 1,
|
|
656
|
-
validator.check("dout shape", dout_shape, "x shape", x_shape,
|
|
657
|
-
validator.check("min shape", min_shape, "max shape", max_shape,
|
|
658
|
-
validator.check_int(len(min_shape), 1,
|
|
679
|
+
validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
|
|
680
|
+
validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
|
|
681
|
+
validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
|
|
682
|
+
validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
|
|
659
683
|
self.check_broadcast(min_shape, x_shape)
|
|
660
684
|
return x_shape, min_shape, max_shape
|
|
661
685
|
|
|
@@ -699,15 +723,15 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
|
|
|
699
723
|
num_bits=8,
|
|
700
724
|
narrow_range=False):
|
|
701
725
|
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
|
|
702
|
-
self.num_bits = validator.check_int_range(self.num_bits, 2, 16,
|
|
726
|
+
self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
|
|
703
727
|
self.narrow_range = validator.check_value_type(
|
|
704
728
|
'narrow_range', narrow_range, (bool,), self.name)
|
|
705
729
|
|
|
706
730
|
def infer_shape(self, x_shape, min_shape, max_shape):
|
|
707
|
-
validator.check_int(len(x_shape), 1,
|
|
708
|
-
validator.check("min shape", min_shape, "max shape", max_shape,
|
|
709
|
-
validator.check_int(len(min_shape), 1,
|
|
710
|
-
validator.check("min shape", min_shape[0], "x shape", x_shape[-1],
|
|
731
|
+
validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
|
|
732
|
+
validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
|
|
733
|
+
validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
|
|
734
|
+
validator.check("min shape", min_shape[0], "x shape", x_shape[-1], validator.EQ, self.name)
|
|
711
735
|
return x_shape
|
|
712
736
|
|
|
713
737
|
def infer_dtype(self, x_type, min_type, max_type):
|
|
@@ -757,16 +781,16 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
|
|
|
757
781
|
num_bits=8,
|
|
758
782
|
narrow_range=False):
|
|
759
783
|
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
|
|
760
|
-
self.num_bits = validator.check_int_range(self.num_bits, 2, 16,
|
|
784
|
+
self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
|
|
761
785
|
self.narrow_range = validator.check_value_type(
|
|
762
786
|
'narrow_range', narrow_range, (bool,), self.name)
|
|
763
787
|
|
|
764
788
|
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
|
|
765
|
-
validator.check_int(len(x_shape), 1,
|
|
766
|
-
validator.check("dout shape", dout_shape, "x shape", x_shape,
|
|
767
|
-
validator.check("min shape", min_shape, "max shape", max_shape,
|
|
768
|
-
validator.check_int(len(min_shape), 1,
|
|
769
|
-
validator.check("min shape", min_shape[0], "x shape", x_shape[-1],
|
|
789
|
+
validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
|
|
790
|
+
validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
|
|
791
|
+
validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
|
|
792
|
+
validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
|
|
793
|
+
validator.check("min shape", min_shape[0], "x shape", x_shape[-1], validator.EQ, self.name)
|
|
770
794
|
return x_shape, min_shape, max_shape
|
|
771
795
|
|
|
772
796
|
def infer_dtype(self, dout_type, x_type, min_type, max_type):
|
|
@@ -855,15 +879,15 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
|
|
|
855
879
|
self.narrow_range = validator.check_value_type(
|
|
856
880
|
'narrow_range', narrow_range, (bool,), self.name)
|
|
857
881
|
self.training = validator.check_value_type('training', training, (bool,), self.name)
|
|
858
|
-
self.ema_decay = validator.check_float_range(ema_decay, 0, 1,
|
|
882
|
+
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
|
|
859
883
|
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
|
|
860
884
|
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
|
|
861
885
|
self.init_prim_io_names(inputs=['x', 'min', 'max'],
|
|
862
886
|
outputs=['out'])
|
|
863
887
|
|
|
864
888
|
def infer_shape(self, x_shape, min_shape, max_shape):
|
|
865
|
-
validator.check_int(len(x_shape), 1,
|
|
866
|
-
validator.check("min shape", min_shape, "max shape", max_shape,
|
|
889
|
+
validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
|
|
890
|
+
validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
|
|
867
891
|
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
|
|
868
892
|
return x_shape
|
|
869
893
|
|
|
@@ -909,9 +933,9 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
|
|
|
909
933
|
|
|
910
934
|
def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
|
|
911
935
|
validator.check("dout shape", dout_shape, "x shape",
|
|
912
|
-
x_shape,
|
|
936
|
+
x_shape, validator.EQ, self.name)
|
|
913
937
|
validator.check("min shape", min_shape, "max shape",
|
|
914
|
-
max_shape,
|
|
938
|
+
max_shape, validator.EQ, self.name)
|
|
915
939
|
validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
|
|
916
940
|
return dout_shape
|
|
917
941
|
|
|
@@ -981,11 +1005,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|
|
981
1005
|
'narrow_range', narrow_range, (bool,), self.name)
|
|
982
1006
|
self.training = validator.check_value_type(
|
|
983
1007
|
'training', training, (bool,), self.name)
|
|
984
|
-
self.ema_decay = validator.check_float_range(ema_decay, 0, 1,
|
|
1008
|
+
self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
|
|
985
1009
|
self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
|
|
986
1010
|
self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
|
|
987
1011
|
if self.is_ascend:
|
|
988
|
-
self.channel_axis = validator.check_int_range(channel_axis, 0, 1,
|
|
1012
|
+
self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
|
|
1013
|
+
'channel_axis', self.name)
|
|
989
1014
|
else:
|
|
990
1015
|
self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
|
|
991
1016
|
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
|
|
@@ -994,10 +1019,10 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|
|
994
1019
|
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
|
|
995
1020
|
raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
|
|
996
1021
|
if not self.is_ascend:
|
|
997
|
-
validator.check_int(len(x_shape), 1,
|
|
1022
|
+
validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
|
|
998
1023
|
if len(x_shape) == 1:
|
|
999
1024
|
self.channel_axis = 0
|
|
1000
|
-
validator.check("min shape", min_shape, "max shape", max_shape,
|
|
1025
|
+
validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
|
|
1001
1026
|
validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
|
|
1002
1027
|
validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
|
|
1003
1028
|
return x_shape
|
|
@@ -1093,7 +1118,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
|
|
1093
1118
|
@prim_attr_register
|
|
1094
1119
|
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
|
|
1095
1120
|
"""Initialize batch norm fold layer"""
|
|
1096
|
-
self.momentum = validator.check_float_range(momentum, 0, 1,
|
|
1121
|
+
self.momentum = validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
|
|
1097
1122
|
self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
|
|
1098
1123
|
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
|
1099
1124
|
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
|
@@ -1102,8 +1127,9 @@ class BatchNormFold(PrimitiveWithInfer):
|
|
|
1102
1127
|
outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])
|
|
1103
1128
|
|
|
1104
1129
|
def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
|
|
1105
|
-
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape,
|
|
1106
|
-
validator.check("mean_shape[0]", mean_shape[0], "input channel",
|
|
1130
|
+
validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, validator.EQ, self.name)
|
|
1131
|
+
validator.check("mean_shape[0]", mean_shape[0], "input channel",
|
|
1132
|
+
x_shape[self.channel_axis], validator.EQ, self.name)
|
|
1107
1133
|
validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
|
|
1108
1134
|
return mean_shape, mean_shape, mean_shape, mean_shape
|
|
1109
1135
|
|
|
@@ -1144,13 +1170,13 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
     def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
                     global_step_shape):
         validator.check("d_batch_mean shape", d_batch_mean_shape,
-                        "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
+                        "d_batch_std shape", d_batch_std_shape, validator.EQ, self.name)
         validator.check("d_batch_mean shape", d_batch_mean_shape,
-                        "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
+                        "batch_mean shape", batch_mean_shape, validator.EQ, self.name)
         validator.check("d_batch_mean shape", d_batch_mean_shape,
-                        "batch_std shape", batch_std_shape, Rel.EQ, self.name)
+                        "batch_std shape", batch_std_shape, validator.EQ, self.name)
         validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
-                        "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
+                        "input channel", x_shape[self.channel_axis], validator.EQ, self.name)
         validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
         return x_shape

@@ -1195,9 +1221,10 @@ class CorrectionMul(PrimitiveWithInfer):
                                 outputs=['out'])

     def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
-        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_std shape",
+                        running_std_shape, validator.EQ, self.name)
         validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
-                        Rel.EQ, self.name)
+                        validator.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, batch_std_type, running_std_type):
@@ -1229,11 +1256,11 @@ class CorrectionMulGrad(PrimitiveWithInfer):
                                 outputs=['dx', 'mul_dx'])

     def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
-        validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
+        validator.check("dout shape", dout_shape, "x_shape x", x_shape, validator.EQ, self.name)
         validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
-                        Rel.EQ, self.name)
+                        validator.EQ, self.name)
         validator.check("running_std_shape[0]", running_std_shape[0],
-                        "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
+                        "dout channel size", dout_shape[self.channel_axis], validator.EQ, self.name)
         if context.get_context('device_target') == "Ascend":
             return x_shape, x_shape
         return x_shape, gamma_shape
@@ -1319,14 +1346,16 @@ class BatchNormFold2(PrimitiveWithInfer):

     def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
                     running_mean_shape, global_step_shape):
-        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_std shape",
+                        running_std_shape, validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
+                        batch_mean_shape, validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, validator.EQ, self.name)
         validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
-                        Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
+                        validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, validator.EQ, self.name)
         validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
-                        Rel.EQ, self.name)
+                        validator.EQ, self.name)
         validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
         return x_shape

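The run of EQ checks in BatchNormFold2 all encode one invariant: every folded-batch-norm parameter vector (beta, gamma, batch and running mean and std) shares the per-channel shape of batch_std. Stated compactly with illustrative shapes:

    # Sketch of the invariant behind the repeated EQ checks above.
    batch_std_shape = (64,)
    params = {"running_std": (64,), "batch_mean": (64,), "beta": (64,),
              "running_mean": (64,), "gamma": (64,)}
    for name, shape in params.items():
        assert shape == batch_std_shape, f"{name} shape {shape} != {batch_std_shape}"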
@@ -1369,13 +1398,15 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
     def infer_shape(self, dout_shape, x_shape, gamma_shape,
                     batch_std_shape, batch_mean_shape,
                     running_std_shape, running_mean_shape, global_step_shape):
-        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
+                        batch_mean_shape, validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_std shape",
+                        running_std_shape, validator.EQ, self.name)
         validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
-                        Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
+                        validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, validator.EQ, self.name)
         validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
-                        Rel.EQ, self.name)
+                        validator.EQ, self.name)
         validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
         return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

@@ -1406,7 +1437,7 @@ class BatchNormFoldD(PrimitiveWithInfer):
     def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
         """Initialize _BatchNormFold layer"""
         from mindspore.ops._op_impl._custom_op import batchnorm_fold
-        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
+        self.momentum = validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
         self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
         self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
         self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
@@ -1416,8 +1447,8 @@ class BatchNormFoldD(PrimitiveWithInfer):
                                          'mean_updated', 'variance_updated'])

     def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
-        validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
-        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
+        validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, validator.EQ, self.name)
+        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], validator.EQ, self.name)
         return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape

     def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
@@ -1487,12 +1518,14 @@ class BatchNormFold2D(PrimitiveWithInfer):
                                 outputs=['y'])

     def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
-        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_std shape",
+                        running_std_shape, validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
+                        batch_mean_shape, validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, validator.EQ, self.name)
         validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
-                        Rel.EQ, self.name)
+                        validator.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
@@ -1517,11 +1550,13 @@ class BatchNormFold2GradD(PrimitiveWithInfer):

     def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
                     batch_mean_shape, running_std_shape):
-        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
-        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
+                        batch_mean_shape, validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "running_std shape",
+                        running_std_shape, validator.EQ, self.name)
+        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, validator.EQ, self.name)
         validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
-                        Rel.EQ, self.name)
+                        validator.EQ, self.name)
         return gamma_shape, gamma_shape, gamma_shape, dout_shape

     def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
@@ -1553,7 +1588,7 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
                                 outputs=['dout_reduce', 'dout_x_reduce'])

     def infer_shape(self, dout_shape, x_shape):
-        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
+        validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
         return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)

     def infer_dtype(self, dout_type, x_type):
@@ -1595,17 +1630,17 @@ class ActsULQ(PrimitiveWithInfer):
     def __init__(self, fixed_min=False, num_bits=8):
         validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
         validator.check_value_type("num_bits", num_bits, [int], self.name)
-        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
+        validator.check_int(num_bits, 8, validator.EQ, "value of num_bits", self.name)

     def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
         """infer shape of primitive"""
-        validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name)
-        validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name)
+        validator.check_int(len(clamp_min_shape), len(x_shape), validator.EQ, "dims of clamp_min", self.name)
+        validator.check_int(len(clamp_max_shape), len(x_shape), validator.EQ, "dims of clamp_max", self.name)

         x_shape_len = len(x_shape)
         for i in range(x_shape_len):
-            validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name)
-            validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name)
+            validator.check_int(clamp_min_shape[i], 1, validator.EQ, "dims of clamp_min", self.name)
+            validator.check_int(clamp_max_shape[i], 1, validator.EQ, "dims of clamp_max", self.name)

         return x_shape, x_shape, x_shape, x_shape

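ActsULQ's checks pin num_bits to exactly 8 and require the clamp tensors to have the same rank as x with every dimension equal to 1, i.e. scalars broadcastable over x. With illustrative shapes:

    # Illustrative mirror of ActsULQ.infer_shape's constraints.
    x_shape = (32, 64, 14, 14)
    clamp_min_shape = (1, 1, 1, 1)
    clamp_max_shape = (1, 1, 1, 1)
    assert len(clamp_min_shape) == len(x_shape)
    assert len(clamp_max_shape) == len(x_shape)
    assert all(d == 1 for d in clamp_min_shape + clamp_max_shape)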
@@ -1746,12 +1781,12 @@ class WtsARQ(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, num_bits, offset_flag):
         validator.check_value_type("num_bits", num_bits, [int], self.name)
-        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
+        validator.check_int(num_bits, 8, validator.EQ, "value of num_bits", self.name)
         validator.check_value_type("offset_flag", offset_flag, [bool], self.name)

     def infer_shape(self, w_shape, w_min_shape, w_max_shape):
-        validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name)
-        validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name)
+        validator.check_int(len(w_min_shape), len(w_shape), validator.EQ, "dims of w_min", self.name)
+        validator.check_int(len(w_max_shape), len(w_shape), validator.EQ, "dims of w_max", self.name)
         return w_shape

     def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
@@ -1808,6 +1843,6 @@ class IFMR(Primitive):
         validator.check_value_type("search_range", search_range, [list, tuple], self.name)
         for item in search_range:
             validator.check_positive_float(item, "item of search_range", self.name)
-        validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
+        validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], validator.GE, self.name)
         validator.check_value_type("search_step", search_step, [float], self.name)
         validator.check_value_type("offset_flag", with_offset, [bool], self.name)
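IFMR's search_range validation amounts to two conditions: both entries are positive floats, and the pair is ordered so that search_range[1] >= search_range[0]. As a quick illustrative check with assumed example values:

    # Illustrative mirror of IFMR's search_range validation.
    search_range = (0.7, 1.3)
    assert all(isinstance(v, float) and v > 0 for v in search_range)
    assert search_range[1] >= search_range[0]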