mindspore 1.10.0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/_checkparam.py
CHANGED
|
@@ -18,10 +18,8 @@ from __future__ import absolute_import
|
|
|
18
18
|
import re
|
|
19
19
|
import inspect
|
|
20
20
|
import math
|
|
21
|
-
from enum import Enum
|
|
22
21
|
from functools import reduce, wraps
|
|
23
|
-
from itertools import repeat
|
|
24
|
-
from collections import deque
|
|
22
|
+
from itertools import repeat
|
|
25
23
|
from collections.abc import Iterable
|
|
26
24
|
import numpy as np
|
|
27
25
|
|
|
@@ -31,71 +29,92 @@ from mindspore.common import dtype as mstype
|
|
|
31
29
|
from mindspore._c_expression import Tensor as Tensor_
|
|
32
30
|
|
|
33
31
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
return
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
32
|
+
EQ = 1 # ==
|
|
33
|
+
NE = 2 # !=
|
|
34
|
+
LT = 3 # <
|
|
35
|
+
LE = 4 # <=
|
|
36
|
+
GT = 5 # >
|
|
37
|
+
GE = 6 # >=
|
|
38
|
+
# scalar range check
|
|
39
|
+
INC_NEITHER = 7 # (), include neither
|
|
40
|
+
INC_LEFT = 8 # [), include left
|
|
41
|
+
INC_RIGHT = 9 # (], include right
|
|
42
|
+
INC_BOTH = 10 # [], include both
|
|
43
|
+
# collection in, not in
|
|
44
|
+
IN = 11
|
|
45
|
+
NOT_IN = 12
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _check_binary_rel(val1, val2, rel):
|
|
49
|
+
"""check binary relation"""
|
|
50
|
+
if rel == EQ:
|
|
51
|
+
return val1 == val2
|
|
52
|
+
if rel == NE:
|
|
53
|
+
return val1 != val2
|
|
54
|
+
if rel == LT:
|
|
55
|
+
return val1 < val2
|
|
56
|
+
if rel == LE:
|
|
57
|
+
return val1 <= val2
|
|
58
|
+
if rel == GT:
|
|
59
|
+
return val1 > val2
|
|
60
|
+
if rel == GE:
|
|
61
|
+
return val1 >= val2
|
|
62
|
+
if rel == IN:
|
|
63
|
+
return val1 in val2
|
|
64
|
+
if rel == NOT_IN:
|
|
65
|
+
return val1 not in val2
|
|
66
|
+
|
|
67
|
+
return False
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _check_inc_rel(val, lower, upper, rel):
|
|
71
|
+
"""check include relation"""
|
|
72
|
+
if rel == INC_NEITHER:
|
|
73
|
+
return not (val <= lower or val >= upper)
|
|
74
|
+
if rel == INC_LEFT:
|
|
75
|
+
return not (val < lower or val >= upper)
|
|
76
|
+
if rel == INC_RIGHT:
|
|
77
|
+
return not (val <= lower or val > upper)
|
|
78
|
+
if rel == INC_BOTH:
|
|
79
|
+
return not (val < lower or val > upper)
|
|
80
|
+
|
|
81
|
+
return False
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _format_str_one_value(value, rel):
|
|
85
|
+
"""format string"""
|
|
86
|
+
if rel == EQ:
|
|
87
|
+
return "= {}".format(value)
|
|
88
|
+
if rel == NE:
|
|
89
|
+
return "!= {}".format(value)
|
|
90
|
+
if rel == LT:
|
|
91
|
+
return "< {}".format(value)
|
|
92
|
+
if rel == LE:
|
|
93
|
+
return "<= {}".format(value)
|
|
94
|
+
if rel == GT:
|
|
95
|
+
return "> {}".format(value)
|
|
96
|
+
if rel == GE:
|
|
97
|
+
return ">= {}".format(value)
|
|
98
|
+
if rel == IN:
|
|
99
|
+
return "in {}".format(value)
|
|
100
|
+
if rel == NOT_IN:
|
|
101
|
+
return "not in {}".format(value)
|
|
102
|
+
|
|
103
|
+
return ""
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _format_str_two_value(val1, val2, rel):
|
|
107
|
+
"""format string"""
|
|
108
|
+
if rel == INC_NEITHER:
|
|
109
|
+
return "({}, {})".format(val1, val2)
|
|
110
|
+
if rel == INC_LEFT:
|
|
111
|
+
return "[{}, {})".format(val1, val2)
|
|
112
|
+
if rel == INC_RIGHT:
|
|
113
|
+
return "({}, {}]".format(val1, val2)
|
|
114
|
+
if rel == INC_BOTH:
|
|
115
|
+
return "[{}, {}]".format(val1, val2)
|
|
116
|
+
|
|
117
|
+
return ""
|
|
99
118
|
|
|
100
119
|
|
|
101
120
|
def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret_five=False,
|
|
@@ -106,71 +125,99 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret
|
|
|
106
125
|
|
|
107
126
|
def _raise_message(third_one_flag=False, three_input_flag=False):
|
|
108
127
|
if third_one_flag:
|
|
109
|
-
raise ValueError(
|
|
110
|
-
|
|
128
|
+
raise ValueError("For '{}', the depth of parameter '{}' must be 1, but got {}." \
|
|
129
|
+
.format(prim_name, arg_name, ret_value[-3]))
|
|
111
130
|
if three_input_flag:
|
|
112
|
-
raise ValueError(
|
|
113
|
-
|
|
114
|
-
raise ValueError(
|
|
115
|
-
|
|
131
|
+
raise ValueError("For '{}', the parameter '{}' must be an positive integer " \
|
|
132
|
+
"or a tuple of three positive integer, but got {}.".format(prim_name, arg_name, arg_value))
|
|
133
|
+
raise ValueError("For '{}', the parameter '{}' must be an positive integer " \
|
|
134
|
+
"or a tuple of three {}positive integer, but got {}" \
|
|
135
|
+
.format(prim_name, arg_name, 'or five ' if allow_five else '', arg_value))
|
|
116
136
|
|
|
117
137
|
def _get_return_value():
|
|
138
|
+
def _check():
|
|
139
|
+
if not isinstance(arg_value, int):
|
|
140
|
+
if len(arg_value) == 5:
|
|
141
|
+
if not allow_five:
|
|
142
|
+
_raise_message()
|
|
143
|
+
elif not len(arg_value) == 3:
|
|
144
|
+
_raise_message()
|
|
145
|
+
|
|
146
|
+
_check()
|
|
118
147
|
if isinstance(arg_value, int):
|
|
119
148
|
ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
|
|
120
149
|
elif len(arg_value) == 3:
|
|
121
150
|
ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
|
|
122
|
-
|
|
123
|
-
if not allow_five:
|
|
124
|
-
_raise_message()
|
|
151
|
+
else: # case: len(arg_value) == 5
|
|
125
152
|
ret = arg_value if ret_five else (arg_value[2], arg_value[3], arg_value[4])
|
|
126
|
-
|
|
127
|
-
_raise_message()
|
|
153
|
+
|
|
128
154
|
return ret
|
|
129
155
|
|
|
130
|
-
|
|
156
|
+
def _check_value(ret_value):
|
|
157
|
+
for item in ret_value:
|
|
158
|
+
if isinstance(item, int) and not isinstance(item, bool):
|
|
159
|
+
if greater_zero and item > 0:
|
|
160
|
+
continue
|
|
161
|
+
if not greater_zero and item >= 0:
|
|
162
|
+
continue
|
|
163
|
+
_raise_message()
|
|
164
|
+
|
|
165
|
+
def _check_third_one(ret_value):
|
|
166
|
+
if third_one:
|
|
167
|
+
if ret_value[-3] != 1:
|
|
168
|
+
_raise_message(third_one_flag=third_one)
|
|
169
|
+
|
|
170
|
+
check_value_type(arg_name, arg_value, (int, tuple), prim_name)
|
|
131
171
|
if three_input and isinstance(arg_value, tuple):
|
|
132
172
|
if len(arg_value) != 3:
|
|
133
173
|
_raise_message(three_input_flag=three_input)
|
|
134
174
|
ret_value = _get_return_value()
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
if greater_zero and item > 0:
|
|
138
|
-
continue
|
|
139
|
-
if not greater_zero and item >= 0:
|
|
140
|
-
continue
|
|
141
|
-
_raise_message()
|
|
142
|
-
|
|
143
|
-
if third_one:
|
|
144
|
-
if ret_value[-3] != 1:
|
|
145
|
-
_raise_message(third_one_flag=third_one)
|
|
175
|
+
_check_value(ret_value)
|
|
176
|
+
_check_third_one(ret_value)
|
|
146
177
|
|
|
147
178
|
return tuple(ret_value)
|
|
148
179
|
|
|
149
180
|
|
|
150
|
-
def
|
|
181
|
+
def _check_dup(axes):
|
|
182
|
+
for item in axes:
|
|
183
|
+
count = 0
|
|
184
|
+
for item2 in axes:
|
|
185
|
+
if item == item2:
|
|
186
|
+
count += 1
|
|
187
|
+
|
|
188
|
+
if count > 1:
|
|
189
|
+
raise ValueError(f"The element of parameter 'axis' can not be duplicate, but got {axes}.")
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
|
|
151
193
|
"""
|
|
152
194
|
Check argument integer.
|
|
153
195
|
|
|
154
196
|
Usage:
|
|
155
|
-
- arg_value =
|
|
197
|
+
- arg_value = _check_number(arg_value, 2, GT, int, "value", None)
|
|
156
198
|
"""
|
|
157
|
-
rel_fn = Rel.get_fns(rel)
|
|
158
199
|
prim_name = f"For \'{prim_name}\', the " if prim_name else 'The '
|
|
159
200
|
arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
|
|
160
|
-
prim_info = f'{prim_name}' + f'{arg_name}'
|
|
161
|
-
if isinstance(arg_value, arg_type):
|
|
162
|
-
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
|
|
163
|
-
raise ValueError(f"{prim_info} must be a legal value, but got '{arg_value}'.")
|
|
164
|
-
else:
|
|
165
|
-
raise TypeError(f"{prim_info} must be {arg_type.__name__}, but got '{type(arg_value).__name__}'")
|
|
166
201
|
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
202
|
+
def _check_param():
|
|
203
|
+
prim_info = f'{prim_name}' + f'{arg_name}'
|
|
204
|
+
if isinstance(arg_value, arg_type):
|
|
205
|
+
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
|
|
206
|
+
raise ValueError(f"{prim_info} must be a legal value, but got '{arg_value}'.")
|
|
207
|
+
else:
|
|
208
|
+
raise TypeError(f"{prim_info} must be {arg_type.__name__}, but got '{type(arg_value).__name__}'")
|
|
209
|
+
|
|
210
|
+
type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
|
|
211
|
+
rel_ret = _check_binary_rel(arg_value, value, rel)
|
|
212
|
+
if type_mismatch or not rel_ret:
|
|
213
|
+
rel_str = _format_str_one_value(value, rel)
|
|
214
|
+
msg = f"{prim_info} must be {arg_type.__name__} and must {rel_str}, " \
|
|
215
|
+
f"but got '{arg_value}' with type '{type(arg_value).__name__}'."
|
|
216
|
+
if type_mismatch:
|
|
217
|
+
raise TypeError(msg)
|
|
218
|
+
raise ValueError(msg)
|
|
173
219
|
|
|
220
|
+
_check_param()
|
|
174
221
|
return arg_value
|
|
175
222
|
|
|
176
223
|
|
|
@@ -185,11 +232,16 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
|
|
|
185
232
|
"""
|
|
186
233
|
prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
|
|
187
234
|
arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
235
|
+
|
|
236
|
+
def _check_param():
|
|
237
|
+
if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
|
|
238
|
+
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
|
|
239
|
+
raise ValueError(f"{prim_name} {arg_name} must be a legal float, but got '{arg_value}'.")
|
|
240
|
+
else:
|
|
241
|
+
raise TypeError("{} type of {} must be {}, but got '{}'".format(
|
|
242
|
+
prim_name, arg_name, arg_type.__name__, type(arg_value).__name__))
|
|
243
|
+
_check_param()
|
|
244
|
+
return arg_value
|
|
193
245
|
|
|
194
246
|
|
|
195
247
|
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
|
|
@@ -197,899 +249,940 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg
|
|
|
197
249
|
Method for checking whether an int value is in some range.
|
|
198
250
|
|
|
199
251
|
Usage:
|
|
200
|
-
- number = check_number_range(number, 0.0, 1.0,
|
|
201
|
-
- number = check_number_range(number, 0, 1,
|
|
252
|
+
- number = check_number_range(number, 0.0, 1.0, INC_NEITHER, "number", float) # number in [0.0, 1.0]
|
|
253
|
+
- number = check_number_range(number, 0, 1, INC_NEITHER, "number", int) # number in [0, 1]
|
|
202
254
|
"""
|
|
203
|
-
rel_fn = Rel.get_fns(rel)
|
|
204
255
|
prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
|
|
205
256
|
arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
257
|
+
|
|
258
|
+
def _check_param():
|
|
259
|
+
type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
|
|
260
|
+
if type_mismatch:
|
|
261
|
+
raise TypeError("{} {} must be '{}', but got '{}'.".format(
|
|
262
|
+
prim_name, arg_name, value_type.__name__, type(arg_value).__name__))
|
|
263
|
+
|
|
264
|
+
if not _check_inc_rel(arg_value, lower_limit, upper_limit, rel):
|
|
265
|
+
rel_str = _format_str_two_value(lower_limit, upper_limit, rel)
|
|
266
|
+
raise ValueError("{} {} must be in range of {}, but got {} with type '{}'.".format(
|
|
267
|
+
prim_name, arg_name, rel_str, arg_value, type(arg_value).__name__))
|
|
268
|
+
_check_param()
|
|
214
269
|
return arg_value
|
|
215
270
|
|
|
216
271
|
|
|
217
|
-
|
|
218
|
-
|
|
272
|
+
def is_stub_tensor(tensor):
|
|
273
|
+
return hasattr(tensor, "stub")
|
|
274
|
+
|
|
219
275
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
|
|
276
|
+
def check(arg_name, arg_value, value_name, value, rel=EQ, prim_name=None, excp_cls=ValueError):
|
|
277
|
+
"""
|
|
278
|
+
Method for judging relation between two int values or list/tuple made up of ints.
|
|
279
|
+
This method is not suitable for judging relation between floats, since it does not consider float error.
|
|
280
|
+
"""
|
|
281
|
+
def _check():
|
|
282
|
+
if not _check_binary_rel(arg_value, value, rel):
|
|
283
|
+
rel_str = _format_str_one_value(f'{value_name}: {value}', rel)
|
|
229
284
|
msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
285
|
+
msg_subject = f"{msg_prefix} \'{arg_name}\'" if " " not in arg_name else f"{msg_prefix} {arg_name}"
|
|
286
|
+
raise excp_cls(f'{msg_subject} should be {rel_str}, but got {arg_value}.')
|
|
287
|
+
_check()
|
|
288
|
+
return arg_value
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
|
|
292
|
+
"""
|
|
293
|
+
Checks input integer value `arg_value` compare to `value`.
|
|
294
|
+
|
|
295
|
+
Usage:
|
|
296
|
+
- number = check_int(number, 0, GE, "number", None) # number >= 0
|
|
297
|
+
"""
|
|
298
|
+
return _check_number(arg_value, value, rel, int, arg_name, prim_name)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def check_is_int(arg_value, arg_name=None, prim_name=None):
|
|
302
|
+
"""
|
|
303
|
+
Checks input value is float type or not.
|
|
304
|
+
|
|
305
|
+
Usage:
|
|
306
|
+
- number = check_is_int(number, int)
|
|
307
|
+
- number = check_is_int(number, int, "bias")
|
|
308
|
+
- number = check_is_int(number, int, "bias", "bias_class")
|
|
309
|
+
"""
|
|
310
|
+
return check_is_number(arg_value, int, arg_name, prim_name)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
|
|
314
|
+
"""
|
|
315
|
+
Checks input integer value `arg_value` compare to `value`.
|
|
316
|
+
|
|
317
|
+
Usage:
|
|
318
|
+
- number = check_int(number, 0, GE, "number", None) # number >= 0
|
|
319
|
+
"""
|
|
320
|
+
return _check_number(arg_value, value, EQ, int, arg_name, prim_name)
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def check_positive_int(arg_value, arg_name=None, prim_name=None):
|
|
324
|
+
"""
|
|
325
|
+
Check argument is positive integer, which mean arg_value > 0.
|
|
326
|
+
|
|
327
|
+
Usage:
|
|
328
|
+
- number = check_positive_int(number)
|
|
329
|
+
- number = check_positive_int(number, "bias")
|
|
330
|
+
"""
|
|
331
|
+
return _check_number(arg_value, 0, GT, int, arg_name, prim_name)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def check_positive_int_sequence(sequence, arg_name=None, prim_name=None):
|
|
335
|
+
"""
|
|
336
|
+
Check argument is positive int sequence, which mean all element > 0 in sequence.
|
|
337
|
+
|
|
338
|
+
Usage:
|
|
339
|
+
- sequence = check_positive_int_sequence(sequence)
|
|
340
|
+
- sequence = check_positive_int_sequence(sequence, "dims")
|
|
341
|
+
"""
|
|
342
|
+
for idx, element in enumerate(sequence):
|
|
343
|
+
arg_idx = '{}[{}]'.format(arg_name if arg_name else 'arg_name', idx)
|
|
344
|
+
_check_number(element, 0, GT, int, arg_idx, prim_name)
|
|
345
|
+
return sequence
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def check_negative_int(arg_value, arg_name=None, prim_name=None):
|
|
349
|
+
"""
|
|
350
|
+
Check argument is negative integer, which mean arg_value < 0.
|
|
351
|
+
|
|
352
|
+
Usage:
|
|
353
|
+
- number = check_negative_int(number)
|
|
354
|
+
- number = check_negative_int(number, "bias")
|
|
355
|
+
"""
|
|
356
|
+
return _check_number(arg_value, 0, LT, int, arg_name, prim_name)
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
|
|
360
|
+
"""
|
|
361
|
+
Check argument is non-negative integer, which mean arg_value <= 0.
|
|
362
|
+
|
|
363
|
+
Usage:
|
|
364
|
+
- number = check_non_positive_int(number)
|
|
365
|
+
- number = check_non_positive_int(number, "bias")
|
|
366
|
+
"""
|
|
367
|
+
return _check_number(arg_value, 0, LE, int, arg_name, prim_name)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
|
|
371
|
+
"""
|
|
372
|
+
Check argument is non-negative integer, which mean arg_value >= 0.
|
|
373
|
+
|
|
374
|
+
Usage:
|
|
375
|
+
- number = check_non_negative_int(number)
|
|
376
|
+
- number = check_non_negative_int(number, "bias")
|
|
377
|
+
"""
|
|
378
|
+
return _check_number(arg_value, 0, GE, int, arg_name, prim_name)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def check_non_negative_int_sequence(sequence, arg_name=None, prim_name=None):
|
|
382
|
+
"""
|
|
383
|
+
Check argument is positive sequence, which mean all element >= 0 in sequence.
|
|
384
|
+
|
|
385
|
+
Usage:
|
|
386
|
+
- sequence = check_non_negative_int_sequence(sequence)
|
|
387
|
+
- sequence = check_non_negative_int_sequence(sequence, "dims")
|
|
388
|
+
"""
|
|
389
|
+
for idx, element in enumerate(sequence):
|
|
390
|
+
arg_idx = '{}[{}]'.format(arg_name if arg_name else 'arg_name', idx)
|
|
391
|
+
_check_number(element, 0, GE, int, arg_idx, prim_name)
|
|
392
|
+
return sequence
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
|
|
396
|
+
"""
|
|
397
|
+
Checks input float value `arg_value` compare to `value`.
|
|
398
|
+
|
|
399
|
+
Usage:
|
|
400
|
+
- number = check_float(number, 0.0, GE, "number", None) # number >= 0
|
|
401
|
+
"""
|
|
402
|
+
return _check_number(arg_value, value, rel, float, arg_name, prim_name)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def check_is_float(arg_value, arg_name=None, prim_name=None):
|
|
406
|
+
"""
|
|
407
|
+
Checks input value is float type or not.
|
|
408
|
+
|
|
409
|
+
Usage:
|
|
410
|
+
- number = check_is_float(number)
|
|
411
|
+
- number = check_is_float(number, "bias")
|
|
412
|
+
- number = check_is_float(number, "bias", "bias_class")
|
|
413
|
+
"""
|
|
414
|
+
return check_is_number(arg_value, float, arg_name, prim_name)
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def check_positive_float(arg_value, arg_name=None, prim_name=None):
|
|
418
|
+
"""
|
|
419
|
+
Check argument is positive float, which mean arg_value > 0.
|
|
420
|
+
|
|
421
|
+
Usage:
|
|
422
|
+
- number = check_positive_float(number)
|
|
423
|
+
- number = check_positive_float(number, "bias")
|
|
424
|
+
- number = check_positive_float(number, "bias", "bias_class")
|
|
425
|
+
"""
|
|
426
|
+
return _check_number(arg_value, 0, GT, float, arg_name, prim_name)
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def check_positive_float_sequence(sequence, arg_name=None, prim_name=None):
|
|
430
|
+
"""
|
|
431
|
+
Check argument is positive sequence, which mean all element > 0 in sequence.
|
|
432
|
+
|
|
433
|
+
Usage:
|
|
434
|
+
- sequence = check_positive_float_sequence(sequence)
|
|
435
|
+
- sequence = check_positive_float_sequence(sequence, "dims")
|
|
436
|
+
"""
|
|
437
|
+
for idx, element in enumerate(sequence):
|
|
438
|
+
arg_idx = '{}[{}]'.format(arg_name if arg_name else 'arg_name', idx)
|
|
439
|
+
_check_number(element, 0, GT, float, arg_idx, prim_name)
|
|
440
|
+
return sequence
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def check_negative_float(arg_value, arg_name=None, prim_name=None):
|
|
444
|
+
"""
|
|
445
|
+
Check argument is negative float, which mean arg_value < 0.
|
|
446
|
+
|
|
447
|
+
Usage:
|
|
448
|
+
- number = check_negative_float(number)
|
|
449
|
+
- number = check_negative_float(number, "bias")
|
|
450
|
+
"""
|
|
451
|
+
return _check_number(arg_value, 0, LT, float, arg_name, prim_name)
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
|
|
455
|
+
"""
|
|
456
|
+
Check argument is non-negative float, which mean arg_value <= 0.
|
|
457
|
+
|
|
458
|
+
Usage:
|
|
459
|
+
- number = check_non_positive_float(number)
|
|
460
|
+
- number = check_non_positive_float(number, "bias")
|
|
461
|
+
"""
|
|
462
|
+
return _check_number(arg_value, 0, LE, float, arg_name, prim_name)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
|
|
466
|
+
"""
|
|
467
|
+
Check argument is non-negative float, which mean arg_value >= 0.
|
|
468
|
+
|
|
469
|
+
Usage:
|
|
470
|
+
- number = check_non_negative_float(number)
|
|
471
|
+
- number = check_non_negative_float(number, "bias")
|
|
472
|
+
"""
|
|
473
|
+
return _check_number(arg_value, 0, GE, float, arg_name, prim_name)
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def check_number(arg_name, arg_value, value, rel, prim_name):
|
|
477
|
+
"""Number value judgment."""
|
|
478
|
+
def _check():
|
|
479
|
+
if not _check_binary_rel(arg_value, value, rel):
|
|
480
|
+
rel_str = _format_str_one_value(value, rel)
|
|
481
|
+
raise ValueError(f'For \'{prim_name}\', the argument \'{arg_name}\' ' \
|
|
482
|
+
f'must {rel_str}, but got {arg_value}.')
|
|
483
|
+
_check()
|
|
484
|
+
return arg_value
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def check_isinstance(arg_name, arg_value, classes):
|
|
488
|
+
"""Check arg isinstance of classes"""
|
|
489
|
+
def _check():
|
|
430
490
|
if not isinstance(arg_value, classes):
|
|
431
491
|
raise ValueError(f'The parameter \'{arg_name}\' must be isinstance of {classes}, but got {arg_value}.')
|
|
432
|
-
|
|
492
|
+
_check()
|
|
493
|
+
return arg_value
|
|
433
494
|
|
|
434
|
-
@staticmethod
|
|
435
|
-
def check_bool(arg_value, arg_name=None, prim_name=None):
|
|
436
|
-
"""
|
|
437
|
-
Check argument is instance of bool.
|
|
438
495
|
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
496
|
+
def check_bool(arg_value, arg_name=None, prim_name=None):
|
|
497
|
+
"""
|
|
498
|
+
Check argument is instance of bool.
|
|
499
|
+
|
|
500
|
+
Usage:
|
|
501
|
+
- has_bias = check_bool(has_bias)
|
|
502
|
+
- has_bias = check_bool(has_bias, "has_bias")
|
|
503
|
+
"""
|
|
504
|
+
prim_name = f"For '{prim_name}', the" if prim_name else 'The'
|
|
505
|
+
arg_name = f"'{arg_name}'" if arg_name else 'input value'
|
|
506
|
+
|
|
507
|
+
def _check():
|
|
443
508
|
if not isinstance(arg_value, bool):
|
|
444
|
-
prim_name = f"For '{prim_name}', the" if prim_name else 'The'
|
|
445
|
-
arg_name = f"'{arg_name}'" if arg_name else 'input value'
|
|
446
509
|
raise TypeError(f"{prim_name} {arg_name} must be a bool, but got {type(arg_value).__name__}.")
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
@staticmethod
|
|
450
|
-
def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
|
|
451
|
-
"""
|
|
452
|
-
Method for checking whether input value is in int range.
|
|
453
|
-
|
|
454
|
-
Usage:
|
|
455
|
-
- number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1]
|
|
456
|
-
- number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1]
|
|
457
|
-
"""
|
|
458
|
-
return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
|
|
459
|
-
|
|
460
|
-
@staticmethod
|
|
461
|
-
def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
|
|
462
|
-
"""
|
|
463
|
-
Method for checking whether input value is in float range.
|
|
464
|
-
|
|
465
|
-
Usage:
|
|
466
|
-
- number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0]
|
|
467
|
-
- number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0]
|
|
468
|
-
"""
|
|
469
|
-
return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
|
|
470
|
-
|
|
471
|
-
@staticmethod
|
|
472
|
-
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
|
|
473
|
-
"""
|
|
474
|
-
Check whether string is in some value list.
|
|
475
|
-
|
|
476
|
-
Usage:
|
|
477
|
-
- method = check_string(method, ["string1", "string2", "string3"], "method")
|
|
478
|
-
"""
|
|
479
|
-
if isinstance(arg_value, str) and arg_value in valid_values:
|
|
480
|
-
return arg_value
|
|
481
|
-
arg_name = arg_name if arg_name else "parameter"
|
|
482
|
-
msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
|
|
483
|
-
raise ValueError(f"{msg_prefix} '{arg_name}' must be str and must be in '{valid_values}',"
|
|
484
|
-
f" but got '{arg_value}'.")
|
|
485
|
-
|
|
486
|
-
@staticmethod
|
|
487
|
-
def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
|
|
488
|
-
if reg is None:
|
|
489
|
-
# Named string regular expression
|
|
490
|
-
reg = r"^\w+[0-9a-zA-Z\_\.]*$"
|
|
491
|
-
if re.match(reg, target, flag) is None:
|
|
492
|
-
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
|
493
|
-
raise ValueError("{} '{}' is illegal, it must be match regular'{}' by flags'{}.'".format(
|
|
494
|
-
prim_name, target, reg, flag))
|
|
495
|
-
return True
|
|
510
|
+
_check()
|
|
511
|
+
return arg_value
|
|
496
512
|
|
|
497
|
-
@staticmethod
|
|
498
|
-
def check_file_name_by_regular(target, reg=None, prim_name=None):
|
|
499
|
-
"""Check whether file name is legitimate."""
|
|
500
|
-
if not isinstance(target, str):
|
|
501
|
-
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
|
502
|
-
raise TypeError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target)))
|
|
503
|
-
if target.endswith("\\") or target.endswith("/"):
|
|
504
|
-
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
|
505
|
-
raise ValueError(f"{prim_name} '{target}' cannot be a directory path.")
|
|
506
|
-
if reg is None:
|
|
507
|
-
reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$"
|
|
508
|
-
if re.match(reg, target) is None:
|
|
509
|
-
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
|
510
|
-
raise ValueError("{} '{}' is illegal, it must be match regular '{}'.".format(
|
|
511
|
-
prim_name, target, reg))
|
|
512
513
|
|
|
513
|
-
|
|
514
|
+
def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
|
|
515
|
+
"""
|
|
516
|
+
Method for checking whether input value is in int range.
|
|
517
|
+
|
|
518
|
+
Usage:
|
|
519
|
+
- number = check_int_range(number, 0, 1, INC_NEITHER) # number in [0, 1]
|
|
520
|
+
- number = check_int_range(number, 0, 1, INC_NEITHER, "number") # number in [0, 1]
|
|
521
|
+
"""
|
|
522
|
+
return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
|
|
526
|
+
"""
|
|
527
|
+
Method for checking whether input value is in float range.
|
|
528
|
+
|
|
529
|
+
Usage:
|
|
530
|
+
- number = check_float_range(number, 0.0, 1.0, INC_NEITHER) # number in [0.0, 1.0]
|
|
531
|
+
- number = check_float_range(number, 0.0, 1.0, INC_NEITHER, "number") # number in [0.0, 1.0]
|
|
532
|
+
"""
|
|
533
|
+
return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
|
|
537
|
+
"""
|
|
538
|
+
Check whether string is in some value list.
|
|
539
|
+
|
|
540
|
+
Usage:
|
|
541
|
+
- method = check_string(method, ["string1", "string2", "string3"], "method")
|
|
542
|
+
"""
|
|
543
|
+
arg_name = arg_name if arg_name else "parameter"
|
|
544
|
+
msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
|
|
545
|
+
|
|
546
|
+
def _check():
|
|
547
|
+
if not (isinstance(arg_value, str) and arg_value in valid_values):
|
|
548
|
+
raise ValueError(f"{msg_prefix} '{arg_name}' must be str and must be in '{valid_values}'," \
|
|
549
|
+
f" but got '{arg_value}'.")
|
|
550
|
+
_check()
|
|
551
|
+
return arg_value
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
|
|
555
|
+
if reg is None:
|
|
556
|
+
# Named string regular expression
|
|
557
|
+
reg = r"^\w+[0-9a-zA-Z\_\.]*$"
|
|
558
|
+
if re.match(reg, target, flag) is None:
|
|
559
|
+
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
|
560
|
+
raise ValueError("{} '{}' is illegal, it must be match regular'{}' by flags'{}.'".format(
|
|
561
|
+
prim_name, target, reg, flag))
|
|
562
|
+
return True
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
def check_file_name_by_regular(target, reg=None, prim_name=None):
|
|
566
|
+
"""Check whether file name is legitimate."""
|
|
567
|
+
if not isinstance(target, str):
|
|
568
|
+
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
|
569
|
+
raise TypeError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target)))
|
|
570
|
+
if target.endswith("\\") or target.endswith("/"):
|
|
571
|
+
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
|
572
|
+
raise ValueError(f"{prim_name} '{target}' cannot be a directory path.")
|
|
573
|
+
if reg is None:
|
|
574
|
+
reg = r"^[0-9a-zA-Z@\_\-\.\:\/\\]+$"
|
|
575
|
+
if re.match(reg, target) is None:
|
|
576
|
+
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
|
577
|
+
raise ValueError("{} '{}' is illegal, it must be match regular '{}'.".format(
|
|
578
|
+
prim_name, target, reg))
|
|
579
|
+
|
|
580
|
+
return True
|
|
581
|
+
|
|
514
582
|
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
if
|
|
531
|
-
if mstype._issubclass_(type_, template_type): # pylint: disable=W0212
|
|
532
|
-
hit = True
|
|
533
|
-
break
|
|
534
|
-
elif type_ is template_type:
|
|
583
|
+
def check_pad_value_by_mode(pad_mode, padding, prim_name):
|
|
584
|
+
"""Validates value of padding according to pad_mode"""
|
|
585
|
+
if pad_mode != 'pad' and padding != 0:
|
|
586
|
+
raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'," \
|
|
587
|
+
f" but got {padding}.")
|
|
588
|
+
return padding
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
|
|
592
|
+
"""Checks whether some type is subclass of another type"""
|
|
593
|
+
if not isinstance(template_types, Iterable):
|
|
594
|
+
template_types = (template_types,)
|
|
595
|
+
hit = False
|
|
596
|
+
for template_type in template_types:
|
|
597
|
+
if isinstance(template_type, mstype.Type):
|
|
598
|
+
if mstype._issubclass_(type_, template_type): # pylint: disable=W0212
|
|
535
599
|
hit = True
|
|
536
600
|
break
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
601
|
+
elif type_ is template_type:
|
|
602
|
+
hit = True
|
|
603
|
+
break
|
|
604
|
+
if not hit:
|
|
605
|
+
if addition_error_info is None:
|
|
606
|
+
addition_error_info = ''
|
|
607
|
+
else:
|
|
608
|
+
addition_error_info = ' ' + addition_error_info
|
|
609
|
+
type_str = (f"type '{type(type_).__name__}'" if isinstance(type_, (tuple, list)) else str(type_))
|
|
610
|
+
raise TypeError(f"For '{prim_name}', the element of '{arg_name}'" \
|
|
611
|
+
f" must be {'one of ' if len(template_types) > 1 else ''}" \
|
|
612
|
+
f"{', '.join((str(x) for x in template_types))}, but got {type_str}" \
|
|
613
|
+
f"{addition_error_info}.The supported data types depend on the hardware that" \
|
|
614
|
+
f" executes the operator, for more details, please refer to the MindSpore official " \
|
|
615
|
+
f"website to get more information about the data type.")
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
def check_valid_input(arg_name, arg_value, prim_name):
|
|
619
|
+
"""Checks valid value."""
|
|
620
|
+
def _check():
|
|
553
621
|
if arg_value is None:
|
|
554
|
-
raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}'
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
else
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
num_types = len(valid_types)
|
|
666
|
-
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
667
|
-
raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
|
|
668
|
-
f"{type_names if num_types > 1 else type_names[0]}, "
|
|
669
|
-
f"but got '{arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}'.")
|
|
670
|
-
|
|
671
|
-
if isinstance(arg_type, type(mstype.tensor)):
|
|
672
|
-
arg_type = arg_type.element_type()
|
|
673
|
-
if arg_type not in valid_types:
|
|
674
|
-
raise_error_msg()
|
|
675
|
-
return arg_type
|
|
676
|
-
|
|
677
|
-
@staticmethod
|
|
678
|
-
def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
|
|
679
|
-
"""Checks whether shape is ori_shape reduced on axis"""
|
|
680
|
-
axis_origin = axis
|
|
681
|
-
axis = axis if isinstance(axis, Iterable) else (axis,)
|
|
682
|
-
exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
|
|
683
|
-
if list(shape) != exp_shape:
|
|
684
|
-
raise ValueError(f"For '{prim_name}', "
|
|
685
|
-
f"the shape of parameter '{arg_name1}' reduce on 'axis': {axis_origin} must "
|
|
686
|
-
f"be equal to the shape of '{arg_name2}': {shape}, but got {ori_shape}.")
|
|
687
|
-
|
|
688
|
-
@staticmethod
|
|
689
|
-
def check_astype_dtype(dtype):
|
|
690
|
-
"""Check whether dtype is a valid input, and convert to mstype"""
|
|
691
|
-
all_types = mstype.__dtype__ + ["int", "float", "bool"]
|
|
692
|
-
if isinstance(dtype, str):
|
|
693
|
-
if dtype.lower() not in all_types:
|
|
694
|
-
raise TypeError(f"For Tensor.astype, the input type must be one of {all_types}, but got '{dtype}'.")
|
|
695
|
-
dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower()))
|
|
696
|
-
elif isinstance(dtype, type):
|
|
697
|
-
dtype = mstype.pytype_to_dtype(dtype)
|
|
698
|
-
elif not dtype in mstype.number_type + (mstype.bool_,):
|
|
699
|
-
raise TypeError(f"For Tensor.astype, the input type must be one of {mstype.number_type + (mstype.bool_,)},"
|
|
700
|
-
f" but got '{dtype}'.")
|
|
701
|
-
return dtype
|
|
702
|
-
|
|
703
|
-
@staticmethod
|
|
704
|
-
def check_transpose_axis(axes, ndim):
|
|
705
|
-
"""Check the axis argument for tensor.transpose"""
|
|
706
|
-
if not axes or (len(axes) == 1 and axes[0] is None):
|
|
707
|
-
return tuple(range(ndim-1, -1, -1))
|
|
708
|
-
|
|
709
|
-
if len(axes) == 1:
|
|
710
|
-
perm = axes[0]
|
|
711
|
-
# if only one argument provided, it must be tuple or list
|
|
712
|
-
if isinstance(perm, list):
|
|
713
|
-
perm = tuple(perm)
|
|
714
|
-
else:
|
|
715
|
-
if not isinstance(perm, tuple):
|
|
716
|
-
raise TypeError(f"For Tensor.transpose, the parameter 'axes' must be a tuple/list, "
|
|
717
|
-
f"or series of integer, but got {type(axes[0])}")
|
|
718
|
-
return perm
|
|
622
|
+
raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}'" \
|
|
623
|
+
f"can not be None, but got {arg_value}.")
|
|
624
|
+
_check()
|
|
625
|
+
return arg_value
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
def check_types_same_and_valid(args, valid_values, prim_name):
|
|
629
|
+
"""Checks whether the types of inputs are the same and valid."""
|
|
630
|
+
|
|
631
|
+
def _check_type_valid(arg):
|
|
632
|
+
arg_key, arg_val = arg
|
|
633
|
+
elem_type = arg_val
|
|
634
|
+
check_subclass(arg_key, elem_type, valid_values, prim_name)
|
|
635
|
+
return (arg_key, elem_type)
|
|
636
|
+
|
|
637
|
+
def _check_types_same(arg1, arg2):
|
|
638
|
+
arg1_name, arg1_type = arg1
|
|
639
|
+
arg2_name, arg2_type = arg2
|
|
640
|
+
if arg1_type != arg2_type:
|
|
641
|
+
raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' should be same as '{arg1_name}'," \
|
|
642
|
+
f" but got '{arg1_name}' with type {arg1_type}" \
|
|
643
|
+
f" and '{arg2_name}' with type {arg2_type}.")
|
|
644
|
+
return arg1
|
|
645
|
+
|
|
646
|
+
elem_types = map(_check_type_valid, args.items())
|
|
647
|
+
reduce(_check_types_same, elem_types)
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
|
|
651
|
+
"""Checks whether the element types of input tensors are the same and valid."""
|
|
652
|
+
valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
|
|
653
|
+
tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
|
|
654
|
+
check_types_same_and_valid(args, tensor_types, prim_name)
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
|
|
658
|
+
"""Checks whether the element types of input tensors are valid."""
|
|
659
|
+
valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
|
|
660
|
+
tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
|
|
661
|
+
check_subclass(arg_name, arg_type, tensor_types, prim_name)
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
|
|
665
|
+
"""
|
|
666
|
+
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
|
|
667
|
+
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
|
|
668
|
+
"""
|
|
669
|
+
|
|
670
|
+
def _check_argument_type(arg):
|
|
671
|
+
arg_key, arg_val = arg
|
|
672
|
+
if isinstance(arg_val, type(mstype.tensor)):
|
|
673
|
+
arg_val = arg_val.element_type()
|
|
674
|
+
if arg_val not in valid_values:
|
|
675
|
+
raise TypeError(f'For \'{prim_name}\', the type of \'{arg_key}\' must be in {valid_values},' \
|
|
676
|
+
f' but got {arg_val}.')
|
|
677
|
+
return arg
|
|
678
|
+
|
|
679
|
+
def _check_types_same(arg1, arg2):
|
|
680
|
+
arg1_name, arg1_type = arg1
|
|
681
|
+
arg2_name, arg2_type = arg2
|
|
682
|
+
except_flag = False
|
|
683
|
+
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
|
|
684
|
+
arg1_type = arg1_type.element_type()
|
|
685
|
+
arg2_type = arg2_type.element_type()
|
|
686
|
+
elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
|
|
687
|
+
pass
|
|
688
|
+
elif allow_mix:
|
|
689
|
+
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
|
|
690
|
+
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
|
|
691
|
+
else:
|
|
692
|
+
except_flag = True
|
|
693
|
+
|
|
694
|
+
if except_flag or arg1_type != arg2_type:
|
|
695
|
+
raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' must be same as '{arg1_name}'," \
|
|
696
|
+
f" but got '{arg1_name}' with type {arg1_type}" \
|
|
697
|
+
f" and '{arg2_name}' with type {arg2_type}.")
|
|
698
|
+
return arg1
|
|
699
|
+
|
|
700
|
+
args_map = map(_check_argument_type, args.items())
|
|
701
|
+
reduce(_check_types_same, args_map)
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
|
|
705
|
+
"""Checks whether a value is instance of some types."""
|
|
706
|
+
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
|
707
|
+
|
|
708
|
+
def raise_error_msg(cond, arg_value):
|
|
709
|
+
"""func for raising error message when check failed"""
|
|
710
|
+
if not cond:
|
|
711
|
+
return
|
|
712
|
+
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
|
|
713
|
+
num_types = len(valid_types)
|
|
714
|
+
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
715
|
+
raise TypeError(f'{msg_prefix} type of \'{arg_name}\' should be {"one of " if num_types > 1 else ""}' \
|
|
716
|
+
f'\'{type_names if num_types > 1 else type_names[0]}\', ' \
|
|
717
|
+
f'but got type \'{type(arg_value).__name__}\'.')
|
|
718
|
+
|
|
719
|
+
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
|
|
720
|
+
# `check_value_type('x', True, [bool, int])` will check pass
|
|
721
|
+
cond = isinstance(arg_value, bool) and bool not in tuple(valid_types)
|
|
722
|
+
raise_error_msg(cond, arg_value)
|
|
723
|
+
if isinstance(arg_value, float) and float not in tuple(valid_types):
|
|
724
|
+
arg_value = round(arg_value, 6)
|
|
725
|
+
cond = not isinstance(arg_value, tuple(valid_types))
|
|
726
|
+
raise_error_msg(cond, arg_value)
|
|
727
|
+
return arg_value
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
def check_type_name(arg_name, arg_type, valid_types, prim_name):
|
|
731
|
+
"""Checks whether a type in some specified types"""
|
|
732
|
+
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
|
|
719
733
|
|
|
734
|
+
def raise_error_msg(cond, arg_type):
|
|
735
|
+
"""func for raising error message when check failed"""
|
|
736
|
+
if not cond:
|
|
737
|
+
return
|
|
738
|
+
type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
|
|
739
|
+
num_types = len(valid_types)
|
|
740
|
+
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
741
|
+
raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \
|
|
742
|
+
f"{type_names if num_types > 1 else type_names[0]}, " \
|
|
743
|
+
f"but got '{arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}'.")
|
|
744
|
+
|
|
745
|
+
if isinstance(arg_type, type(mstype.tensor)):
|
|
746
|
+
arg_type = arg_type.element_type()
|
|
747
|
+
cond = arg_type not in valid_types
|
|
748
|
+
raise_error_msg(cond, arg_type)
|
|
749
|
+
return arg_type
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
|
|
753
|
+
"""Checks whether shape is ori_shape reduced on axis"""
|
|
754
|
+
axis_origin = axis
|
|
755
|
+
axis = axis if isinstance(axis, Iterable) else (axis,)
|
|
756
|
+
exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
|
|
757
|
+
if list(shape) != exp_shape:
|
|
758
|
+
raise ValueError(f"For '{prim_name}', " \
|
|
759
|
+
f"the shape of parameter '{arg_name1}' reduce on 'axis': {axis_origin} must " \
|
|
760
|
+
f"be equal to the shape of '{arg_name2}': {shape}, but got {ori_shape}.")
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
def check_astype_dtype(dtype):
|
|
764
|
+
"""Check whether dtype is a valid input, and convert to mstype"""
|
|
765
|
+
all_types = mstype.__dtype__ + ["int", "float", "bool"]
|
|
766
|
+
if isinstance(dtype, str):
|
|
767
|
+
if dtype.lower() not in all_types:
|
|
768
|
+
raise TypeError(f"For Tensor.astype, the input type must be one of {all_types}, but got '{dtype}'.")
|
|
769
|
+
dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower()))
|
|
770
|
+
elif isinstance(dtype, type):
|
|
771
|
+
dtype = mstype.pytype_to_dtype(dtype)
|
|
772
|
+
elif not dtype in mstype.number_type + (mstype.bool_,):
|
|
773
|
+
raise TypeError(f"For Tensor.astype, the input type must be one of {mstype.number_type + (mstype.bool_,)}," \
|
|
774
|
+
f" but got '{dtype}'.")
|
|
775
|
+
return dtype
|
|
776
|
+
|
|
777
|
+
|
|
778
|
+
def check_transpose_axis(axes, ndim):
|
|
779
|
+
"""Check the axis argument for tensor.transpose"""
|
|
780
|
+
def _check_dim():
|
|
720
781
|
# if multiple arguments provided, it must be `ndim` number of ints
|
|
721
782
|
if len(axes) != ndim:
|
|
722
|
-
raise ValueError(f"For Tensor.transpose, the number of axes must be equal to the dimension of Tensor, "
|
|
783
|
+
raise ValueError(f"For Tensor.transpose, the number of axes must be equal to the dimension of Tensor, " \
|
|
723
784
|
f"but got {len(axes)} in the number of axes.")
|
|
724
|
-
return axes
|
|
725
785
|
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
if isinstance(
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
@staticmethod
|
|
747
|
-
def check_flatten_order(order):
|
|
748
|
-
"""Check flatten function input order"""
|
|
749
|
-
if not isinstance(order, str):
|
|
750
|
-
raise TypeError(f"For Tensor.flatten, the parameter 'order' must be a string, but got {type(order)}")
|
|
751
|
-
if order not in ('C', 'F'):
|
|
752
|
-
raise ValueError(f"For Tensor.flatten, the parameter 'order' must be 'C' or 'F', but got '{order}'")
|
|
753
|
-
return order
|
|
754
|
-
|
|
755
|
-
@staticmethod
|
|
756
|
-
def check_swapaxes_axis(axes, ndim):
|
|
757
|
-
"""Check all the axes argument for tensor.swapaxes"""
|
|
758
|
-
if isinstance(axes, int):
|
|
759
|
-
Validator.check_axis_in_range(axes, ndim)
|
|
760
|
-
return axes % ndim
|
|
761
|
-
if isinstance(axes, (tuple, list)):
|
|
762
|
-
for axis in axes:
|
|
763
|
-
if not isinstance(axis, int):
|
|
764
|
-
raise TypeError(f"For Tensor.swapaxes, the axis argument must be integer, but got {type(axis)}.")
|
|
765
|
-
Validator.check_axis_in_range(axis, ndim)
|
|
766
|
-
axes = tuple(map(lambda x: x % ndim, axes))
|
|
767
|
-
return axes
|
|
768
|
-
raise TypeError(f"For Tensor.swapaxes, the argument 'axes' must be integer, list or tuple for check, "
|
|
769
|
-
f"but got {type(axes)}.")
|
|
770
|
-
|
|
771
|
-
@staticmethod
|
|
772
|
-
def prepare_shape_for_squeeze(shape, axes):
|
|
773
|
-
"""
|
|
774
|
-
Creates the squeezed new shape based on the tensor and given axes.
|
|
775
|
-
|
|
776
|
-
Args:
|
|
777
|
-
shape (tuple): the shape of the tensor
|
|
778
|
-
axes Union[int, tuple(int), list(int)]: the axes with dimensions need to
|
|
779
|
-
be squeezed.
|
|
780
|
-
|
|
781
|
-
Returns:
|
|
782
|
-
new_shape(tuple): the shape with dimensions squeezed.
|
|
783
|
-
"""
|
|
784
|
-
new_shape = []
|
|
785
|
-
ndim = len(shape)
|
|
786
|
-
|
|
787
|
-
# Convert to set
|
|
788
|
-
if isinstance(axes, int):
|
|
789
|
-
if axes >= ndim or axes < -ndim:
|
|
790
|
-
raise ValueError(f"For Tensor.squeeze, "
|
|
791
|
-
f"the 'axis' must be in the range of [-{ndim}, {ndim}), but got {axes}.")
|
|
792
|
-
axes = {axes}
|
|
793
|
-
|
|
794
|
-
elif isinstance(axes, (list, tuple)):
|
|
795
|
-
for axis in axes:
|
|
796
|
-
if axis >= ndim or axis < -ndim:
|
|
797
|
-
raise ValueError(f"For Tensor.squeeze, "
|
|
798
|
-
f"the 'axis' must be in the range of [-{ndim}, {ndim}), but got {axis}.")
|
|
799
|
-
axes = set(axes)
|
|
786
|
+
if not axes or (len(axes) == 1 and axes[0] is None):
|
|
787
|
+
return tuple(range(ndim-1, -1, -1))
|
|
788
|
+
|
|
789
|
+
if len(axes) == 1:
|
|
790
|
+
perm = axes[0]
|
|
791
|
+
# if only one argument provided, it must be tuple or list
|
|
792
|
+
if isinstance(perm, list):
|
|
793
|
+
perm = tuple(perm)
|
|
794
|
+
else:
|
|
795
|
+
if not isinstance(perm, tuple):
|
|
796
|
+
raise TypeError(f"For Tensor.transpose, the parameter 'axes' must be a tuple/list, " \
|
|
797
|
+
f"or series of integer, but got {type(axes[0])}")
|
|
798
|
+
return perm
|
|
799
|
+
|
|
800
|
+
_check_dim()
|
|
801
|
+
return axes
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
def check_reshape_shp(shp):
|
|
805
|
+
"""Check the shape argument for tensor.reshape"""
|
|
800
806
|
|
|
807
|
+
if len(shp) == 1:
|
|
808
|
+
new_shape = shp[0]
|
|
809
|
+
# if only one argument provided, it must be int, tuple or list
|
|
810
|
+
if isinstance(new_shape, int):
|
|
811
|
+
return shp
|
|
812
|
+
if isinstance(new_shape, list):
|
|
813
|
+
new_shape = tuple(new_shape)
|
|
801
814
|
else:
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
815
|
+
if not isinstance(new_shape, tuple):
|
|
816
|
+
raise TypeError(
|
|
817
|
+
f"For Tensor.reshape, the parameter 'shape' must be an integer, or tuple/list, " \
|
|
818
|
+
f"or series of integer, but got {type(shp[0])}")
|
|
819
|
+
return new_shape
|
|
820
|
+
|
|
821
|
+
return shp
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
def check_flatten_order(order):
|
|
825
|
+
"""Check flatten function input order"""
|
|
826
|
+
if not isinstance(order, str):
|
|
827
|
+
raise TypeError(f"For Tensor.flatten, the parameter 'order' must be a string, but got {type(order)}")
|
|
828
|
+
if order not in ('C', 'F'):
|
|
829
|
+
raise ValueError(f"For Tensor.flatten, the parameter 'order' must be 'C' or 'F', but got '{order}'")
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
def check_swapaxes_axis(axes, ndim):
|
|
833
|
+
"""Check all the axes argument for ops.swapaxes"""
|
|
834
|
+
if isinstance(axes, int):
|
|
835
|
+
return check_axis_in_range(axes, ndim)
|
|
836
|
+
if isinstance(axes, (tuple, list)):
|
|
837
|
+
for axis in axes:
|
|
838
|
+
if not isinstance(axis, int):
|
|
839
|
+
raise TypeError(f"For ops.swapaxes, the axis argument must be integer, but got {type(axis)}.")
|
|
840
|
+
check_axis_in_range(axis, ndim)
|
|
841
|
+
tmp = ()
|
|
842
|
+
for x in axes:
|
|
843
|
+
tmp = tmp + ((x + ndim) % ndim,)
|
|
844
|
+
return tmp
|
|
845
|
+
raise TypeError(f"For ops.swapaxes, the argument 'axes' must be integer, list or tuple for check, " \
|
|
846
|
+
f"but got {type(axes)}.")
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
def prepare_shape_for_squeeze(shape, axes):
|
|
850
|
+
"""
|
|
851
|
+
Creates the squeezed new shape based on the tensor and given axes.
|
|
852
|
+
|
|
853
|
+
Args:
|
|
854
|
+
shape (tuple): the shape of the tensor
|
|
855
|
+
axes Union[int, tuple(int), list(int)]: the axes with dimensions need to
|
|
856
|
+
be squeezed.
|
|
857
|
+
|
|
858
|
+
Returns:
|
|
859
|
+
new_shape(tuple): the shape with dimensions squeezed.
|
|
860
|
+
"""
|
|
861
|
+
new_shape = ()
|
|
862
|
+
ndim = len(shape)
|
|
863
|
+
|
|
864
|
+
def _check(axes, ndim):
|
|
865
|
+
if axes >= ndim or axes < -ndim:
|
|
866
|
+
raise ValueError("For Tensor.squeeze, the 'axis' must be in the range of [-{0}, {0}), but got {1}." \
|
|
867
|
+
.format(ndim, axes))
|
|
868
|
+
|
|
869
|
+
def _check_for(axes, ndim):
|
|
870
|
+
for axis in axes:
|
|
871
|
+
_check(axis, ndim)
|
|
872
|
+
|
|
873
|
+
if isinstance(axes, int):
|
|
874
|
+
_check(axes, ndim)
|
|
875
|
+
axes = (axes,)
|
|
876
|
+
elif isinstance(axes, (list, tuple)):
|
|
877
|
+
_check_for(axes, ndim)
|
|
878
|
+
new_axes = ()
|
|
879
|
+
for item in axes:
|
|
880
|
+
if item not in new_axes:
|
|
881
|
+
new_axes += (item,)
|
|
882
|
+
axes = new_axes
|
|
883
|
+
else:
|
|
884
|
+
raise TypeError("For Tensor.squeeze, the parameter 'axes' must be one of [int, tuple, list], but got {}" \
|
|
885
|
+
.format(type(axes)))
|
|
886
|
+
|
|
887
|
+
def _check_axis(s, idx, axes, ndim):
|
|
888
|
+
# if an axis is selected with shape entry greater than one, an error is raised.
|
|
889
|
+
if s != 1 and ((idx in axes) or (idx - ndim in axes)):
|
|
890
|
+
raise ValueError(f"For Tensor.squeeze, the shape of parameter 'axis' {axes} must be 1, but got {s}.")
|
|
891
|
+
|
|
892
|
+
for idx, s in enumerate(shape):
|
|
893
|
+
_check_axis(s, idx, axes, ndim)
|
|
894
|
+
if s != 1 or (idx not in axes) and (idx - ndim not in axes):
|
|
895
|
+
new_shape = new_shape + (s,)
|
|
896
|
+
|
|
897
|
+
return new_shape
|
|
898
|
+
|
|
899
|
+
|
|
900
|
+
def check_axis_in_range(axis, ndim):
|
|
901
|
+
"""Checks axes are with the bounds of ndim"""
|
|
902
|
+
def _check():
|
|
816
903
|
if not isinstance(axis, int):
|
|
817
904
|
raise TypeError(f'The axes must be integers, but got {type(axis)}')
|
|
818
|
-
|
|
905
|
+
|
|
906
|
+
if axis >= ndim or axis < -ndim:
|
|
819
907
|
raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {axis}.")
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
return
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
if type_tuple:
|
|
896
|
-
type_str += "tuple, "
|
|
897
|
-
if type_list:
|
|
898
|
-
type_str += "list, "
|
|
899
|
-
raise TypeError(f"For Tensor.ptp, the axis should be {type_str}, but got {type(axis)}.")
|
|
900
|
-
|
|
901
|
-
@staticmethod
|
|
902
|
-
def check_and_canonicalize_axes(axes, ndim):
|
|
903
|
-
"""Check whether the types and values of input axes are valid."""
|
|
904
|
-
axes = axes if isinstance(axes, tuple) else (axes,)
|
|
905
|
-
new_axes = ()
|
|
906
|
-
for ax in axes:
|
|
908
|
+
|
|
909
|
+
_check()
|
|
910
|
+
return (axis + ndim) % ndim
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
def check_axis_valid(axes, ndim):
|
|
914
|
+
"""
|
|
915
|
+
Checks axes are valid given ndim, and returns axes that can be passed
|
|
916
|
+
to the built-in operator (non-negative, int or tuple)
|
|
917
|
+
"""
|
|
918
|
+
def _check_range(axes):
|
|
919
|
+
for axis in axes:
|
|
920
|
+
check_axis_in_range(axis, ndim)
|
|
921
|
+
|
|
922
|
+
if axes is None:
|
|
923
|
+
axes = tuple(range(ndim))
|
|
924
|
+
return axes
|
|
925
|
+
if isinstance(axes, (tuple, list)):
|
|
926
|
+
_check_range(axes)
|
|
927
|
+
tmp = ()
|
|
928
|
+
for x in axes:
|
|
929
|
+
tmp = tmp + ((x + ndim) % ndim,)
|
|
930
|
+
_check_dup(tmp)
|
|
931
|
+
return tmp
|
|
932
|
+
check_axis_in_range(axes, ndim)
|
|
933
|
+
return (axes % ndim,)
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
def max_(*args):
|
|
937
|
+
return max(*args)
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
def min_(*args):
|
|
941
|
+
return min(*args)
|
|
942
|
+
|
|
943
|
+
|
|
944
|
+
def expanded_shape(ndim, axis_size, axis):
|
|
945
|
+
"""
|
|
946
|
+
Returns a shape with size = 1 for all dimensions
|
|
947
|
+
except at axis.
|
|
948
|
+
"""
|
|
949
|
+
return tuple(axis_size if i == axis else 1 for i in range(ndim))
|
|
950
|
+
|
|
951
|
+
|
|
952
|
+
def tuple_slice(tup, start, end):
|
|
953
|
+
"""get sliced tuple from start and end."""
|
|
954
|
+
return tup[start:end]
|
|
955
|
+
|
|
956
|
+
|
|
957
|
+
def infer_out_shape(*shapes):
|
|
958
|
+
"""
|
|
959
|
+
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
|
|
960
|
+
"""
|
|
961
|
+
def _check(items, max_size, shapes):
|
|
962
|
+
for item in items:
|
|
963
|
+
if item not in (1, max_size):
|
|
964
|
+
raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max on the axis' \
|
|
965
|
+
f'to support broadcast, but got shapes {shapes,}')
|
|
966
|
+
shape_out = ()
|
|
967
|
+
max_len = max([len(it) for it in shapes])
|
|
968
|
+
for i in range(max_len):
|
|
969
|
+
items = [it[i-(max_len-len(it))] if i - (max_len - len(it))
|
|
970
|
+
>= 0 else 1 for it in shapes]
|
|
971
|
+
max_size = 0 if 0 in items else max(items)
|
|
972
|
+
_check(items, max_size, shapes)
|
|
973
|
+
shape_out = shape_out + (max_size,)
|
|
974
|
+
return shape_out
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
|
|
978
|
+
"""Check axis argument type."""
|
|
979
|
+
if type_int and isinstance(axis, int):
|
|
980
|
+
return True
|
|
981
|
+
if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
|
|
982
|
+
for ax in axis:
|
|
907
983
|
if not isinstance(ax, int):
|
|
908
|
-
raise TypeError(f"
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
def
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
984
|
+
raise TypeError(f"For Tensor.ptp, each axis must be integer, but got {type(ax)} in {axis}.")
|
|
985
|
+
return True
|
|
986
|
+
|
|
987
|
+
type_str = ""
|
|
988
|
+
if type_int:
|
|
989
|
+
type_str += "int, "
|
|
990
|
+
if type_tuple:
|
|
991
|
+
type_str += "tuple, "
|
|
992
|
+
if type_list:
|
|
993
|
+
type_str += "list, "
|
|
994
|
+
raise TypeError(f"For Tensor.ptp, the axis should be {type_str}, but got {type(axis)}.")
|
|
995
|
+
|
|
996
|
+
|
|
997
|
+
def check_and_canonicalize_axes(axes, ndim):
|
|
998
|
+
"""Check whether the types and values of input axes are valid."""
|
|
999
|
+
def _check(axes, ax, ndim):
|
|
1000
|
+
if not isinstance(ax, int):
|
|
1001
|
+
raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axes}.")
|
|
1002
|
+
if ax >= ndim or ax < -ndim:
|
|
1003
|
+
raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {ax}.")
|
|
1004
|
+
|
|
1005
|
+
axes = axes if isinstance(axes, tuple) else (axes,)
|
|
1006
|
+
new_axes = ()
|
|
1007
|
+
for ax in axes:
|
|
1008
|
+
_check(axes, ax, ndim)
|
|
1009
|
+
ax = ax if ax >= 0 else ax + ndim
|
|
1010
|
+
new_axes += (ax,)
|
|
1011
|
+
_check_dup(new_axes)
|
|
1012
|
+
return new_axes
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
def check_type_support(dtype, device, supported_dtypes):
|
|
1016
|
+
"""Checks whether the data type is supported."""
|
|
1017
|
+
return dtype in supported_dtypes or not context.get_context('device_target') == device
|
|
1018
|
+
|
|
1019
|
+
|
|
1020
|
+
def check_sparse_tensor_input(indices, values, shape):
|
|
1021
|
+
"""Common input check for SparseTensors."""
|
|
1022
|
+
if not isinstance(indices, Tensor_) and not is_stub_tensor(indices):
|
|
1023
|
+
raise TypeError(f"For SparseTensors, 'indices' must be Tensor, but got {type(indices)}.")
|
|
1024
|
+
if not isinstance(values, Tensor_) and not is_stub_tensor(values):
|
|
1025
|
+
raise TypeError(f"For SparseTensors, 'values' must be Tensor, but got {type(values)}.")
|
|
1026
|
+
if not isinstance(shape, tuple):
|
|
1027
|
+
raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.")
|
|
1028
|
+
|
|
1029
|
+
|
|
1030
|
+
def check_csr_tensor_input(indptr, indices, values, shape):
|
|
1031
|
+
"""Checks inputs type for CSRTensor."""
|
|
1032
|
+
if not isinstance(indptr, Tensor_) and not is_stub_tensor(indptr):
|
|
1033
|
+
raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.")
|
|
1034
|
+
check_sparse_tensor_input(indices, values, shape)
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
|
|
1038
|
+
"""Checks input tensors' shapes for CSRTensor."""
|
|
1039
|
+
# Support empty sparse tensor
|
|
1040
|
+
if (indptr_shp == (0,)) and (indices_shp == (0,)) and (values_shp == (0,)):
|
|
1041
|
+
return
|
|
1042
|
+
shape_size = 1
|
|
1043
|
+
val_shp_size = 1
|
|
1044
|
+
for item in csr_shp:
|
|
1045
|
+
if item <= 0:
|
|
1046
|
+
raise ValueError(f"For CSRTensor, the element of shape must be positive, but got {item}")
|
|
1047
|
+
if not isinstance(item, int):
|
|
1048
|
+
raise TypeError(f"For CSRTensor, the element type of shape must be int, but got {type(item)}")
|
|
1049
|
+
shape_size *= item
|
|
1050
|
+
for item in values_shp:
|
|
1051
|
+
if item <= 0:
|
|
1052
|
+
raise ValueError(f"The element of shape must be positive, but got {item}")
|
|
1053
|
+
val_shp_size *= item
|
|
1054
|
+
if shape_size < val_shp_size:
|
|
1055
|
+
raise ValueError(f"Shape total size: {shape_size} is too small to hold {val_shp_size} non-zero values.")
|
|
1056
|
+
if len(indices_shp) != 1:
|
|
1057
|
+
raise ValueError(f"For CSRTensor, indices must be a 1-dimensional tensor, " \
|
|
1058
|
+
f"but got a {len(indices_shp)} dimension tensor.")
|
|
1059
|
+
if len(indptr_shp) != 1:
|
|
1060
|
+
raise ValueError(f"For CSRTensor, indptr must be a 1-dimensional tensor, " \
|
|
1061
|
+
f"but got a {len(indptr_shp)} dimension tensor.")
|
|
1062
|
+
if csr_shp[0] + 1 != indptr_shp[0]:
|
|
1063
|
+
raise ValueError(f"For CSRTensor, indptr must have length (1 + shape[0]), " \
|
|
1064
|
+
f"but got: {indptr_shp[0]}")
|
|
1065
|
+
if indices_shp[0] != values_shp[0]:
|
|
1066
|
+
err_msg1 = "For CSRTensor, indices and values must equal in their shape, "
|
|
1067
|
+
err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}."
|
|
1068
|
+
raise ValueError(err_msg1 + err_msg2)
|
|
1069
|
+
if len(values_shp) + 1 != len(csr_shp):
|
|
1070
|
+
raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got" \
|
|
1071
|
+
f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: " \
|
|
1072
|
+
f"{len(csr_shp)}")
|
|
1073
|
+
if values_shp[1:] != csr_shp[2:]:
|
|
1074
|
+
raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ]," \
|
|
1075
|
+
f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]" \
|
|
1076
|
+
f"got: {values_shp[1: ]}")
|
|
1077
|
+
|
|
1078
|
+
|
|
1079
|
+
def check_csr_tensor_dtype(indptr_dtype, indices_dtype):
|
|
1080
|
+
"""Checks input tensors' data types for CSRTensor."""
|
|
1081
|
+
if indptr_dtype not in (mstype.int16, mstype.int32, mstype.int64):
|
|
1082
|
+
raise TypeError(f"For CSRTensor, indptr must have int16 or int32 or int64 data type, " \
|
|
1083
|
+
f"but got {indptr_dtype}.")
|
|
1084
|
+
if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64):
|
|
1085
|
+
raise TypeError(f"For CSRTensor, indices must have int16 or int32 or int64 data type, " \
|
|
1086
|
+
f"but got {indices_dtype}.")
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
def check_coo_tensor_input(indices, values, shape):
|
|
1090
|
+
"""Checks inputs type for COOTensor."""
|
|
1091
|
+
check_sparse_tensor_input(indices, values, shape)
|
|
1092
|
+
|
|
1093
|
+
|
|
1094
|
+
def check_coo_tensor_shape(indices_shp, values_shp, coo_shp):
|
|
1095
|
+
"""Checks input tensors' shapes for COOTensor."""
|
|
1096
|
+
if len(coo_shp) != 2:
|
|
1097
|
+
raise ValueError(f"For COOTensor, the length of 'shape' must be 2, but got {coo_shp}.")
|
|
1098
|
+
if (indices_shp == (0,)) and (values_shp == (0,)):
|
|
1099
|
+
return
|
|
1100
|
+
shp_mul = 1
|
|
1101
|
+
for sh in coo_shp:
|
|
1102
|
+
if sh <= 0:
|
|
1103
|
+
raise ValueError(f"For COOTensor, the element of 'shape' must be positive, but got {sh} in {coo_shp}.")
|
|
1104
|
+
if not isinstance(sh, int):
|
|
1105
|
+
raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}")
|
|
1106
|
+
shp_mul *= sh
|
|
1107
|
+
if shp_mul < values_shp[0]:
|
|
1108
|
+
raise ValueError(f"For COOTensor, shape is too small: ({shp_mul}) to hold all values({values_shp[0]}).")
|
|
1109
|
+
if len(indices_shp) != 2:
|
|
1110
|
+
raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}" \
|
|
1111
|
+
f"-dimensional tensor.")
|
|
1112
|
+
if len(values_shp) != 1:
|
|
1113
|
+
raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}" \
|
|
1114
|
+
f"-dimensional tensor.")
|
|
1115
|
+
if indices_shp[0] != values_shp[0]:
|
|
1116
|
+
raise ValueError(f"For COOTensor, 'indices.shape[0]' must be euqal to 'values.shape[0]', but got " \
|
|
1117
|
+
f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.")
|
|
1118
|
+
if indices_shp[1] != 2:
|
|
1119
|
+
raise ValueError(f"For COOTensor, 'indices.shape[1]' must be 2, but got {indices_shp[1]}.")
|
|
1120
|
+
|
|
1121
|
+
|
|
1122
|
+
def check_coo_tensor_dtype(indices_dtype):
|
|
1123
|
+
"""Checks input tensors' data types for COOTensor."""
|
|
1124
|
+
if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64):
|
|
1125
|
+
raise TypeError(f"For COOTensor, the type of 'indices' must be one of [int16, int32, int64], but got " \
|
|
1126
|
+
f"{indices_dtype}.")
|
|
1127
|
+
|
|
1128
|
+
|
|
1129
|
+
def check_dynamic_shape(dyn_elem, actual_input, i):
|
|
1130
|
+
"""Check the consistency of dynamic shape tensors and actual input tensors."""
|
|
1131
|
+
if dyn_elem.dtype != actual_input.dtype:
|
|
1132
|
+
raise TypeError(f"The data type of '{i}'th args in actual input tensors should be '{dyn_elem.dtype}', " \
|
|
1133
|
+
f"but got '{actual_input.dtype}'.")
|
|
1134
|
+
if dyn_elem.ndim != actual_input.ndim:
|
|
1135
|
+
raise ValueError(f"The dimension of '{i}'th args in actual input tensors should be '{dyn_elem.ndim}', " \
|
|
1136
|
+
f"but got '{actual_input.ndim}'.")
|
|
1137
|
+
check_dyn_shape_value_equal(i, dyn_elem.shape, actual_input.shape)
|
|
1138
|
+
|
|
1139
|
+
|
|
1140
|
+
def check_element_type_of_iterable(arg_name, arg_value, valid_types, prim_name=None):
|
|
1141
|
+
"""Check type of the element of a iterabel object, execpt dict."""
|
|
1142
|
+
check_value_type(arg_name, arg_value, [list, tuple], prim_name)
|
|
1143
|
+
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
|
|
1144
|
+
num_types = len(valid_types)
|
|
1145
|
+
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
1146
|
+
for element in arg_value:
|
|
1147
|
+
if not isinstance(element, tuple(valid_types)):
|
|
1148
|
+
raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \
|
|
1149
|
+
f"{type_names if num_types > 1 else type_names[0]}, " \
|
|
1150
|
+
f"but got '{element}' with type '{type(element).__name__}'.")
|
|
1151
|
+
|
|
1152
|
+
|
|
1153
|
+
def check_element_type_of_dict(arg_name, arg_value, key_types, value_types, prim_name=None):
|
|
1154
|
+
"""Check the type of key and value of a dict."""
|
|
1155
|
+
check_value_type(arg_name, arg_value, [dict], prim_name)
|
|
1156
|
+
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
1157
|
+
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in key_types]
|
|
1158
|
+
num_types = len(key_types)
|
|
1159
|
+
for element in arg_value.keys():
|
|
1160
|
+
if not isinstance(element, tuple(key_types)):
|
|
1161
|
+
raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \
|
|
1162
|
+
f"{type_names if num_types > 1 else type_names[0]}, " \
|
|
1163
|
+
f"but got '{element}' with type '{type(element).__name__}'.")
|
|
1164
|
+
|
|
1165
|
+
type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in value_types]
|
|
1166
|
+
num_types = len(value_types)
|
|
1167
|
+
for element in arg_value.values():
|
|
1168
|
+
if not isinstance(element, tuple(value_types)):
|
|
1169
|
+
raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \
|
|
1170
|
+
f"{type_names if num_types > 1 else type_names[0]}, " \
|
|
1171
|
+
f"but got '{element}' with type '{type(element).__name__}'.")
|
|
1172
|
+
|
|
1173
|
+
|
|
1174
|
+
def check_size_and_element_type_of_tuple(arg_name, arg_value, expect_size, expect_element_type, prim_name=None):
|
|
1175
|
+
"""Check the size and element type of a tuple."""
|
|
1176
|
+
check_value_type(arg_name, arg_value, [tuple], prim_name)
|
|
1177
|
+
check_equal_int(len(arg_value), expect_size, arg_name + ' size', prim_name)
|
|
1178
|
+
check_element_type_of_iterable('arg_name', arg_value, [expect_element_type], prim_name)
|
|
1086
1179
|
|
|
1087
1180
|
|
|
1088
1181
|
def check_dyn_shape_value_equal(index, dyn_shape, actual_shape):
|
|
1089
1182
|
"""Check the consistency of dynamic shape and actual input shape."""
|
|
1090
1183
|
for i, x in enumerate(dyn_shape):
|
|
1091
1184
|
if x not in (-1, actual_shape[i]):
|
|
1092
|
-
raise ValueError(f"The {i}th shape value of `{index}`th actual input args should be `{x}`, but got "
|
|
1185
|
+
raise ValueError(f"The {i}th shape value of `{index}`th actual input args should be `{x}`, but got " \
|
|
1093
1186
|
f"`{actual_shape[i]}`.")
|
|
1094
1187
|
|
|
1095
1188
|
|
|
@@ -1107,17 +1200,17 @@ def _expand_tuple(n_dimensions):
|
|
|
1107
1200
|
if not isinstance(m, tuple):
|
|
1108
1201
|
if isinstance(m, int) and not isinstance(m, bool):
|
|
1109
1202
|
return tuple(repeat(m, n_dimensions))
|
|
1110
|
-
raise TypeError(f"When expanding an int number to tuple, input type must be integer or tuple[int], "
|
|
1203
|
+
raise TypeError(f"When expanding an int number to tuple, input type must be integer or tuple[int], " \
|
|
1111
1204
|
f"but got {type(m)}")
|
|
1112
1205
|
|
|
1113
1206
|
if not len(m) is n_dimensions:
|
|
1114
|
-
raise TypeError(f"When expanding an int number to tuple, input tuple dimension must be {n_dimensions}, "
|
|
1207
|
+
raise TypeError(f"When expanding an int number to tuple, input tuple dimension must be {n_dimensions}, " \
|
|
1115
1208
|
f"but got {m}")
|
|
1116
1209
|
|
|
1117
1210
|
for i in m:
|
|
1118
1211
|
if not isinstance(i, int) or isinstance(i, bool):
|
|
1119
|
-
raise TypeError(f"When expanding an int number to tuple, "
|
|
1120
|
-
f"the type of element in input tuple must be
|
|
1212
|
+
raise TypeError(f"When expanding an int number to tuple, " \
|
|
1213
|
+
f"the type of element in input tuple must be an integer, but got {type(i)}.")
|
|
1121
1214
|
return m
|
|
1122
1215
|
|
|
1123
1216
|
return convert
|
|
@@ -1153,8 +1246,8 @@ def check_input_data(*data, data_class):
|
|
|
1153
1246
|
if not ret:
|
|
1154
1247
|
data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class) if isinstance(
|
|
1155
1248
|
data_class, (tuple, list)) else (data_class if data_class is None else data_class.__name__)
|
|
1156
|
-
raise TypeError(f'The type of input data must be in the Union({data_class_str}, '
|
|
1157
|
-
f'tuple[{data_class_str}], list[{data_class_str}], dict[{data_class_str}]), '
|
|
1249
|
+
raise TypeError(f'The type of input data must be in the Union({data_class_str}, ' \
|
|
1250
|
+
f'tuple[{data_class_str}], list[{data_class_str}], dict[{data_class_str}]), ' \
|
|
1158
1251
|
f'but got type {item if item is None else type(item).__name__}.')
|
|
1159
1252
|
|
|
1160
1253
|
|
|
@@ -1208,31 +1301,3 @@ def args_type_check(*type_args, **type_kwargs):
|
|
|
1208
1301
|
|
|
1209
1302
|
|
|
1210
1303
|
_set_record = {}
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
def args_unreset_check(*unreset_args, **unreset_kwargs):
|
|
1214
|
-
"""Check the entered non repeatable setting properties."""
|
|
1215
|
-
|
|
1216
|
-
def unreset_check(func):
|
|
1217
|
-
sig = inspect.signature(func)
|
|
1218
|
-
bound_unreset = sig.bind_partial(*unreset_args, **unreset_kwargs).arguments
|
|
1219
|
-
|
|
1220
|
-
@wraps(func)
|
|
1221
|
-
def wrapper(*args, **kwargs):
|
|
1222
|
-
nonlocal bound_unreset
|
|
1223
|
-
bound_values = sig.bind(*args, **kwargs)
|
|
1224
|
-
argument_dict = bound_values.arguments
|
|
1225
|
-
if "kwargs" in bound_unreset:
|
|
1226
|
-
bound_unreset = bound_unreset["kwargs"]
|
|
1227
|
-
if "kwargs" in argument_dict:
|
|
1228
|
-
argument_dict = argument_dict["kwargs"]
|
|
1229
|
-
for name, value in argument_dict.items():
|
|
1230
|
-
if name in _set_record.keys():
|
|
1231
|
-
raise TypeError("For 'set_context', the parameter '{}' can not be set repeatedly.".format(name))
|
|
1232
|
-
if name in bound_unreset:
|
|
1233
|
-
_set_record[name] = value
|
|
1234
|
-
return func(*args, **kwargs)
|
|
1235
|
-
|
|
1236
|
-
return wrapper
|
|
1237
|
-
|
|
1238
|
-
return unreset_check
|