mindspore 1.10.0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
2
|
+
#
|
|
3
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
# ============================================================================
|
|
17
|
+
|
|
18
|
+
"""Env related operations."""
|
|
19
|
+
from __future__ import absolute_import
|
|
20
|
+
from mindspore.ops.composite.base import MultitypeFuncGraph
|
|
21
|
+
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
22
|
+
from mindspore.ops.primitive import Primitive
|
|
23
|
+
from mindspore.ops.operations import _grad_ops
|
|
24
|
+
from mindspore.ops import operations as P
|
|
25
|
+
|
|
26
|
+
env_get = MultitypeFuncGraph("env_get")
|
|
27
|
+
environ_get = Primitive('EnvironGet')
|
|
28
|
+
ref_to_embed = _grad_ops.RefToEmbed()
|
|
29
|
+
tensor_zeros_like = P.ZerosLike()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@env_get.register("EnvType", "Tensor")
|
|
33
|
+
def _tensor_env_get(env, parameter):
|
|
34
|
+
"""Used to get env."""
|
|
35
|
+
return environ_get(env, ref_to_embed(parameter), tensor_zeros_like(parameter))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@env_get.register("EnvType", "MapTensor")
|
|
39
|
+
def _map_tensor_env_get(env, map_parameter):
|
|
40
|
+
"""Used to get env for map parameter."""
|
|
41
|
+
return environ_get(env, ref_to_embed(map_parameter), zeros_like(map_parameter))
|
|
@@ -13,22 +13,22 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""math Operations."""
|
|
16
|
-
import numpy as np
|
|
17
16
|
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
18
17
|
from mindspore.common import dtype as mstype
|
|
19
|
-
from mindspore
|
|
20
|
-
from mindspore.ops.primitive import constexpr
|
|
18
|
+
from mindspore import _checkparam as validator
|
|
19
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
21
20
|
from mindspore.ops import functional as F
|
|
22
|
-
from mindspore.ops.
|
|
23
|
-
from
|
|
24
|
-
from .. import operations as P
|
|
21
|
+
from mindspore.ops.function.math_func import cummin as cummin_
|
|
22
|
+
from mindspore.ops import operations as P
|
|
25
23
|
|
|
26
24
|
|
|
27
|
-
@
|
|
25
|
+
@_primexpr
|
|
28
26
|
def _check_validate_axis(axis, name):
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
27
|
+
def _check(axis):
|
|
28
|
+
if isinstance(axis, (tuple, list)):
|
|
29
|
+
for idx, item in enumerate(axis):
|
|
30
|
+
validator.check_value_type("axis[%d]" % idx, item, [int], name)
|
|
31
|
+
_check(axis)
|
|
32
32
|
axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
|
|
33
33
|
return axis
|
|
34
34
|
|
|
@@ -46,20 +46,26 @@ def is_const(x):
|
|
|
46
46
|
|
|
47
47
|
def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
|
48
48
|
r"""
|
|
49
|
-
Count number of nonzero elements across axis of input tensor
|
|
49
|
+
Count number of nonzero elements across axis of input tensor.
|
|
50
50
|
|
|
51
51
|
Args:
|
|
52
|
-
x (Tensor): Input data is used to count non-zero numbers.
|
|
53
|
-
|
|
54
|
-
axis (Union[int, tuple(int), list(int)]): The dimensions to reduce.
|
|
55
|
-
|
|
56
|
-
keep_dims (bool):
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
52
|
+
x (Tensor): Input data is used to count non-zero numbers. With shape
|
|
53
|
+
:math:`(N,*)` where :math:`*` means, any number of additional dimensions.
|
|
54
|
+
axis (Union[int, tuple(int), list(int)], optional): The dimensions to reduce.
|
|
55
|
+
Default: (), reduce all dimensions.
|
|
56
|
+
keep_dims (bool, optional): Whether to maintain dimensions specified by `axis`.
|
|
57
|
+
If true, keep these reduced dimensions and the length is 1.
|
|
58
|
+
If false, don't keep these dimensions. Default: False.
|
|
59
|
+
dtype (Union[Number, mindspore.bool\_], optional): The data type of the output tensor.
|
|
60
|
+
Default: mindspore.int32.
|
|
60
61
|
|
|
61
62
|
Returns:
|
|
62
|
-
Tensor, number of nonzero element
|
|
63
|
+
Tensor, number of nonzero element across axis specified by `axis`.
|
|
64
|
+
The data type is specified by `dtype`.
|
|
65
|
+
|
|
66
|
+
Raises:
|
|
67
|
+
TypeError: If `axis` is not int, tuple or list.
|
|
68
|
+
ValueError: If any value in `axis` is not in range [-x.ndim, x.ndim).
|
|
63
69
|
|
|
64
70
|
Supported Platforms:
|
|
65
71
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -102,7 +108,9 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
|
|
102
108
|
not_equal = P.NotEqual()
|
|
103
109
|
cast = P.Cast()
|
|
104
110
|
reduce_sum = P.ReduceSum(keep_dims)
|
|
105
|
-
|
|
111
|
+
zeros = P.Zeros()
|
|
112
|
+
tensor_0 = zeros(x.shape, x.dtype)
|
|
113
|
+
nonzero_bool = not_equal(x, tensor_0)
|
|
106
114
|
# ReduceSum only support float16 or float32 tensor.
|
|
107
115
|
nonzero_val = cast(nonzero_bool, mstype.float32)
|
|
108
116
|
nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)
|
|
@@ -110,7 +118,7 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
|
|
110
118
|
return nonzero_num
|
|
111
119
|
|
|
112
120
|
|
|
113
|
-
@
|
|
121
|
+
@_primexpr
|
|
114
122
|
def _int_to_tuple_conv(axes):
|
|
115
123
|
"""
|
|
116
124
|
Converts ints to tuples in input axes, expected by most validation checks.
|
|
@@ -121,7 +129,7 @@ def _int_to_tuple_conv(axes):
|
|
|
121
129
|
return axes
|
|
122
130
|
|
|
123
131
|
|
|
124
|
-
@
|
|
132
|
+
@_primexpr
|
|
125
133
|
def _check_axes(axes, prim_name=None):
|
|
126
134
|
"""
|
|
127
135
|
Check for validity and type of axes passed to function.
|
|
@@ -154,21 +162,29 @@ def _typecheck_input(x1_type, x2_type, prim_name=None):
|
|
|
154
162
|
f"and x2_type: {x2_type}.")
|
|
155
163
|
|
|
156
164
|
|
|
157
|
-
@
|
|
165
|
+
@_primexpr
|
|
158
166
|
def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
|
|
159
167
|
"""
|
|
160
168
|
Convert from single int axes to 2d tuple if required
|
|
161
169
|
"""
|
|
162
170
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
163
|
-
|
|
171
|
+
|
|
172
|
+
def _check_lt_zero(axes):
|
|
164
173
|
if axes < 0:
|
|
165
174
|
raise ValueError(f"{msg_prefix} 'axes' must be at least 0, but got {axes}.")
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
return [], []
|
|
175
|
+
|
|
176
|
+
def _check_len(axes, x1_shape, x2_shape):
|
|
169
177
|
if axes > len(x1_shape) or axes > len(x2_shape):
|
|
170
178
|
raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
|
|
171
179
|
f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
if isinstance(axes, int):
|
|
183
|
+
_check_lt_zero(axes)
|
|
184
|
+
if axes == 0:
|
|
185
|
+
# outer product, no input validation required
|
|
186
|
+
return [], []
|
|
187
|
+
_check_len(axes, x1_shape, x2_shape)
|
|
172
188
|
x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
|
|
173
189
|
x2_ind = tuple(range(len(x2_shape))[:axes])
|
|
174
190
|
axes = tuple((x1_ind, x2_ind))
|
|
@@ -176,7 +192,7 @@ def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
176
192
|
return axes
|
|
177
193
|
|
|
178
194
|
|
|
179
|
-
@
|
|
195
|
+
@_primexpr
|
|
180
196
|
def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
|
|
181
197
|
"""
|
|
182
198
|
Checks for axes having the correct length according to input, for any value in axis
|
|
@@ -184,25 +200,32 @@ def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
184
200
|
with given inputs.
|
|
185
201
|
"""
|
|
186
202
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
203
|
+
|
|
204
|
+
def _check_len(axes_len, shape_dim_len, x_axes):
|
|
205
|
+
if axes_len > shape_dim_len:
|
|
206
|
+
raise ValueError(f"{msg_prefix} length of element {x_axes} in 'axes' must be less than or equal to "
|
|
207
|
+
f"{shape_dim_len}, but got {axes_len}.")
|
|
208
|
+
|
|
209
|
+
def _check_value(x_axes, min_val, max_val):
|
|
210
|
+
for _, x_value in enumerate(x_axes):
|
|
211
|
+
if x_value > max_val or x_value < min_val:
|
|
212
|
+
raise ValueError(f"{msg_prefix} value in 'axes' must be in range: [{min_val}, {max_val}], "
|
|
213
|
+
f"but got {x_value}.")
|
|
214
|
+
|
|
187
215
|
shapes = [x1_shape, x2_shape]
|
|
188
216
|
|
|
189
217
|
# axis length check
|
|
190
218
|
for ix_input, x_axes in enumerate(axes):
|
|
191
219
|
axes_len = len(x_axes)
|
|
192
220
|
shape_dim_len = len(shapes[ix_input])
|
|
193
|
-
|
|
194
|
-
raise ValueError(f"{msg_prefix} length of element {x_axes} in 'axes' must be less than or equal to "
|
|
195
|
-
f"{shape_dim_len}, but got {axes_len}.")
|
|
221
|
+
_check_len(axes_len, shape_dim_len, x_axes)
|
|
196
222
|
|
|
197
223
|
# axis values range check
|
|
198
224
|
for ix_input, x_axes in enumerate(axes):
|
|
199
225
|
comp_shape = shapes[ix_input]
|
|
200
226
|
max_val = len(comp_shape) - 1
|
|
201
227
|
min_val = -1 * len(comp_shape)
|
|
202
|
-
|
|
203
|
-
if not min_val <= x_value <= max_val:
|
|
204
|
-
raise ValueError(f"{msg_prefix} value in 'axes' must be in range: [{min_val}, {max_val}], "
|
|
205
|
-
f"but got {x_value}.")
|
|
228
|
+
_check_value(x_axes, min_val, max_val)
|
|
206
229
|
|
|
207
230
|
# check axis value with input shape - both ways for axis valid
|
|
208
231
|
invalid_a = False
|
|
@@ -212,23 +235,31 @@ def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
212
235
|
invalid_a = True
|
|
213
236
|
if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0]) - 1 - i]]:
|
|
214
237
|
invalid_b = True
|
|
215
|
-
if invalid_a and invalid_b:
|
|
216
|
-
raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
|
|
217
|
-
f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
|
|
218
|
-
f"'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}, 'axes': {axes}.")
|
|
219
238
|
|
|
239
|
+
def _check(invalid_a, invalid_b, x1_shape, x2_shape, axes):
|
|
240
|
+
if invalid_a and invalid_b:
|
|
241
|
+
raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
|
|
242
|
+
f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
|
|
243
|
+
f"'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}, 'axes': {axes}.")
|
|
220
244
|
|
|
221
|
-
|
|
245
|
+
_check(invalid_a, invalid_b, x1_shape, x2_shape, axes)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@_primexpr
|
|
222
249
|
def _calc_new_shape(shape, axes, position=0):
|
|
223
250
|
"""
|
|
224
251
|
Calculate transpose and reshape parameters for input transformations,
|
|
225
252
|
'position' refers to whether tensor is first or second in the op.
|
|
226
253
|
"""
|
|
227
254
|
contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
|
|
228
|
-
prod_contraction =
|
|
255
|
+
prod_contraction = 1
|
|
256
|
+
for i in contraction_axes:
|
|
257
|
+
prod_contraction *= shape[i]
|
|
229
258
|
free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
|
|
230
|
-
free_dims = tuple(shape[i] for i in free_axes)
|
|
231
|
-
prod_free =
|
|
259
|
+
free_dims = tuple(shape[i] if shape[i] is not None else -1 for i in free_axes)
|
|
260
|
+
prod_free = 1
|
|
261
|
+
for free_dim in free_dims:
|
|
262
|
+
prod_free *= free_dim
|
|
232
263
|
|
|
233
264
|
transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
|
|
234
265
|
new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
|
|
@@ -245,7 +276,7 @@ def tensor_dot(x1, x2, axes):
|
|
|
245
276
|
|
|
246
277
|
Selected dims in both inputs must also match.
|
|
247
278
|
|
|
248
|
-
axes = 0 leads to outer product
|
|
279
|
+
axes = 0 leads to outer product.
|
|
249
280
|
axes = 1 leads to normal matrix multiplication when inputs both 2D.
|
|
250
281
|
axes = 1 is the same as axes = ((1,),(0,)) where both `a` and `b` are 2D.
|
|
251
282
|
axes = 2 is the same as axes = ((1,2),(0,1)) where both `a` and `b` are 3D.
|
|
@@ -288,10 +319,7 @@ def tensor_dot(x1, x2, axes):
|
|
|
288
319
|
# input validity checks
|
|
289
320
|
x1_shape = shape_op(x1)
|
|
290
321
|
x2_shape = shape_op(x2)
|
|
291
|
-
x1_type = F.dtype(x1)
|
|
292
|
-
x2_type = F.dtype(x2)
|
|
293
322
|
axes = _check_axes(axes, 'tensor_dot')
|
|
294
|
-
_typecheck_input(x1_type, x2_type, 'tensor_dot')
|
|
295
323
|
# input compatibility check & axes format update
|
|
296
324
|
axes = _axes_int_check(x1_shape, x2_shape, axes, 'tensor_dot')
|
|
297
325
|
_validate_axes(x1_shape, x2_shape, axes, 'tensor_dot')
|
|
@@ -308,7 +336,7 @@ def tensor_dot(x1, x2, axes):
|
|
|
308
336
|
return final_result
|
|
309
337
|
|
|
310
338
|
|
|
311
|
-
@
|
|
339
|
+
@_primexpr
|
|
312
340
|
def _check_invalid_input(x1_shape, x2_shape, prim_name=None):
|
|
313
341
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
314
342
|
if len(x1_shape) < 2 or len(x2_shape) < 2:
|
|
@@ -329,56 +357,57 @@ def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
|
|
|
329
357
|
f"x1_type: {x1_type} and x2_type: {x2_type}.")
|
|
330
358
|
|
|
331
359
|
|
|
332
|
-
@
|
|
360
|
+
@_primexpr
|
|
333
361
|
def _get_transpose_shape(x2_shape):
|
|
334
362
|
x2_shape_range = tuple(range(len(x2_shape)))
|
|
335
363
|
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
|
|
336
364
|
return x2_shape_transpose
|
|
337
365
|
|
|
338
366
|
|
|
339
|
-
def dot(
|
|
367
|
+
def dot(input, other):
|
|
340
368
|
"""
|
|
341
369
|
Computation a dot product between samples in two tensors.
|
|
342
370
|
|
|
343
371
|
Args:
|
|
344
|
-
|
|
372
|
+
input (Tensor): First tensor in Dot op with datatype float16 or float32,
|
|
345
373
|
The rank must be greater than or equal to 2.
|
|
346
|
-
|
|
374
|
+
other (Tensor): Second tensor in Dot op with datatype float16 or float32,
|
|
347
375
|
The rank must be greater than or equal to 2.
|
|
348
376
|
|
|
349
377
|
Returns:
|
|
350
|
-
Tensor, dot product of
|
|
378
|
+
Tensor, dot product of input and other.
|
|
351
379
|
|
|
352
380
|
Raises:
|
|
353
|
-
TypeError: If type of
|
|
354
|
-
TypeError: If dtype of
|
|
355
|
-
ValueError: If rank of
|
|
381
|
+
TypeError: If type of input and other are not the same.
|
|
382
|
+
TypeError: If dtype of input or other is not float16 or float32.
|
|
383
|
+
ValueError: If rank of input or other less than 2.
|
|
356
384
|
|
|
357
385
|
Supported Platforms:
|
|
358
386
|
``Ascend`` ``GPU`` ``CPU``
|
|
359
387
|
|
|
360
388
|
Examples:
|
|
389
|
+
>>> import numpy as np
|
|
361
390
|
>>> import mindspore
|
|
362
391
|
>>> from mindspore import Tensor, ops
|
|
363
|
-
>>>
|
|
364
|
-
>>>
|
|
365
|
-
>>> output = ops.dot(
|
|
392
|
+
>>> input = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
|
|
393
|
+
>>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
|
|
394
|
+
>>> output = ops.dot(input, other)
|
|
366
395
|
>>> print(output)
|
|
367
396
|
[[[3. 3.]]
|
|
368
397
|
[[3. 3.]]]
|
|
369
398
|
>>> print(output.shape)
|
|
370
399
|
(2, 1, 2)
|
|
371
|
-
>>>
|
|
372
|
-
>>>
|
|
373
|
-
>>> output = ops.dot(
|
|
400
|
+
>>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
|
|
401
|
+
>>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
|
|
402
|
+
>>> output = ops.dot(input, other)
|
|
374
403
|
>>> print(output)
|
|
375
404
|
[[[[3. 3.]]
|
|
376
405
|
[[3. 3.]]]]
|
|
377
406
|
>>> print(output.shape)
|
|
378
407
|
(1, 2, 1, 2)
|
|
379
|
-
>>>
|
|
380
|
-
>>>
|
|
381
|
-
>>> output = ops.dot(
|
|
408
|
+
>>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
|
|
409
|
+
>>> other = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
|
|
410
|
+
>>> output = ops.dot(input, other)
|
|
382
411
|
>>> print(output)
|
|
383
412
|
[[[[3. 3.]
|
|
384
413
|
[3. 3.]]
|
|
@@ -386,9 +415,9 @@ def dot(x1, x2):
|
|
|
386
415
|
[3. 3.]]]]
|
|
387
416
|
>>> print(output.shape)
|
|
388
417
|
(1, 2, 2, 2)
|
|
389
|
-
>>>
|
|
390
|
-
>>>
|
|
391
|
-
>>> output = ops.dot(
|
|
418
|
+
>>> input = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
|
|
419
|
+
>>> other = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
|
|
420
|
+
>>> output = ops.dot(input, other)
|
|
392
421
|
>>> print(output)
|
|
393
422
|
[[[[[3. 3.]]
|
|
394
423
|
[[3. 3.]]]
|
|
@@ -409,34 +438,36 @@ def dot(x1, x2):
|
|
|
409
438
|
reshape_op = P.Reshape()
|
|
410
439
|
transpose_op = P.Transpose()
|
|
411
440
|
matmul_op = P.MatMul(False, False)
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
_typecheck_input_dot(
|
|
417
|
-
_check_invalid_input(
|
|
418
|
-
|
|
419
|
-
if len(
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
mul_result = matmul_op(
|
|
425
|
-
reshape_shape =
|
|
441
|
+
input_shape = shape_op(input)
|
|
442
|
+
other_shape = shape_op(other)
|
|
443
|
+
input_type = F.dtype(input)
|
|
444
|
+
other_type = F.dtype(other)
|
|
445
|
+
_typecheck_input_dot(input_type, other_type, 'dot')
|
|
446
|
+
_check_invalid_input(input_shape, other_shape, 'dot')
|
|
447
|
+
|
|
448
|
+
if len(input_shape) > 2 or len(other_shape) > 2:
|
|
449
|
+
other_shape_transpose = _get_transpose_shape(other_shape)
|
|
450
|
+
other_transpose = transpose_op(other, other_shape_transpose)
|
|
451
|
+
input_reshape = reshape_op(input, (-1, input_shape[-1]))
|
|
452
|
+
other_reshape = reshape_op(other_transpose, (other_shape[-2], -1))
|
|
453
|
+
mul_result = matmul_op(input_reshape, other_reshape)
|
|
454
|
+
reshape_shape = input_shape[:-1] + other_shape[:-2] + other_shape[-1:]
|
|
426
455
|
reshape_shape = (-1,) + reshape_shape[1:]
|
|
427
456
|
return reshape_op(mul_result, reshape_shape)
|
|
428
|
-
return matmul_op(
|
|
457
|
+
return matmul_op(input, other)
|
|
429
458
|
|
|
430
459
|
|
|
431
|
-
@
|
|
460
|
+
@_primexpr
|
|
432
461
|
def _get_batch_size(x1_shape, x2_shape, prim_name=None):
|
|
433
462
|
"""
|
|
434
463
|
Get batch sizes from two inputs
|
|
435
464
|
"""
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
465
|
+
def _check():
|
|
466
|
+
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
467
|
+
if len(x1_shape) < 2 or len(x2_shape) < 2:
|
|
468
|
+
raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
|
|
469
|
+
f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
|
|
470
|
+
_check()
|
|
440
471
|
return x1_shape[0], x2_shape[0]
|
|
441
472
|
|
|
442
473
|
|
|
@@ -453,12 +484,33 @@ def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
|
|
|
453
484
|
f"x2_type: {x2_type}.")
|
|
454
485
|
|
|
455
486
|
|
|
456
|
-
@
|
|
487
|
+
@_primexpr
|
|
457
488
|
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
|
|
458
489
|
"""
|
|
459
490
|
Check whether axes are valid and cast axes from tuple to list
|
|
460
491
|
"""
|
|
461
492
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
493
|
+
|
|
494
|
+
def _check_1(axes):
|
|
495
|
+
if 0 in axes:
|
|
496
|
+
raise ValueError(f"{msg_prefix} 'axes' cannot contain 0, but got axes: {axes}.")
|
|
497
|
+
if len(axes) != 2:
|
|
498
|
+
raise ValueError(f"{msg_prefix} length of 'axes' must be equal to 2, but got {len(axes)}.")
|
|
499
|
+
|
|
500
|
+
def _check_2(axes, x1_shape, x2_shape):
|
|
501
|
+
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
|
|
502
|
+
raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
|
|
503
|
+
f"and axes[1] must be less than or equal to len(x2_shape)."
|
|
504
|
+
f"But got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
505
|
+
|
|
506
|
+
def _check_3(axes, x1_shape, x2_shape):
|
|
507
|
+
if axes == 0:
|
|
508
|
+
raise ValueError(f"{msg_prefix} 'axes' should not be equal to 0, but got {axes}.")
|
|
509
|
+
|
|
510
|
+
if axes > len(x1_shape) or axes > len(x2_shape):
|
|
511
|
+
raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
|
|
512
|
+
f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
513
|
+
|
|
462
514
|
if axes is None:
|
|
463
515
|
if len(x2_shape) == 2:
|
|
464
516
|
axes = [len(x1_shape) - 1, len(x2_shape) - 1]
|
|
@@ -466,10 +518,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
466
518
|
axes = [len(x1_shape) - 1, len(x2_shape) - 2]
|
|
467
519
|
|
|
468
520
|
if isinstance(axes, (list, tuple)):
|
|
469
|
-
|
|
470
|
-
raise ValueError(f"{msg_prefix} 'axes' cannot contain 0, but got axes: {axes}.")
|
|
471
|
-
if len(axes) != 2:
|
|
472
|
-
raise ValueError(f"{msg_prefix} length of 'axes' must be equal to 2, but got {len(axes)}.")
|
|
521
|
+
_check_1(axes)
|
|
473
522
|
if isinstance(axes, tuple):
|
|
474
523
|
axes = list(axes)
|
|
475
524
|
validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
|
|
@@ -481,19 +530,12 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
481
530
|
axes[1] += len(x2_shape)
|
|
482
531
|
validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
|
|
483
532
|
validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
|
|
484
|
-
|
|
485
|
-
raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
|
|
486
|
-
f"and axes[1] must be less than or equal to len(x2_shape)."
|
|
487
|
-
f"But got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
533
|
+
_check_2(axes, x1_shape, x2_shape)
|
|
488
534
|
elif isinstance(axes, int):
|
|
489
|
-
|
|
490
|
-
raise ValueError(f"{msg_prefix} 'axes' should not be equal to 0, but got {axes}.")
|
|
535
|
+
_check_3(axes, x1_shape, x2_shape)
|
|
491
536
|
if axes < 0:
|
|
492
537
|
axes = [axes + len(x1_shape), axes + len(x2_shape)]
|
|
493
538
|
validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
|
|
494
|
-
elif axes > len(x1_shape) or axes > len(x2_shape):
|
|
495
|
-
raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
|
|
496
|
-
f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
497
539
|
else:
|
|
498
540
|
axes = [axes, axes]
|
|
499
541
|
else:
|
|
@@ -502,7 +544,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
|
|
|
502
544
|
return axes
|
|
503
545
|
|
|
504
546
|
|
|
505
|
-
@
|
|
547
|
+
@_primexpr
|
|
506
548
|
def _calc_new_shape_batchdot(shape, axes, position=0):
|
|
507
549
|
"""
|
|
508
550
|
Calculate transpose and reshape parameters for input transformations,
|
|
@@ -510,10 +552,14 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
|
|
|
510
552
|
"""
|
|
511
553
|
axis = axes[position]
|
|
512
554
|
contraction_axes = tuple([axis])
|
|
513
|
-
prod_contraction =
|
|
555
|
+
prod_contraction = 1
|
|
556
|
+
for i in contraction_axes:
|
|
557
|
+
prod_contraction *= shape[i]
|
|
514
558
|
free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
|
|
515
559
|
free_dims = tuple(shape[i] for i in free_axes)
|
|
516
|
-
prod_free =
|
|
560
|
+
prod_free = 1
|
|
561
|
+
for free_dim in free_dims:
|
|
562
|
+
prod_free *= free_dim
|
|
517
563
|
|
|
518
564
|
transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
|
|
519
565
|
transpose_perm = tuple([0]) + transpose_perm
|
|
@@ -522,7 +568,7 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
|
|
|
522
568
|
return new_shape, transpose_perm, free_dims
|
|
523
569
|
|
|
524
570
|
|
|
525
|
-
@
|
|
571
|
+
@_primexpr
|
|
526
572
|
def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
|
|
527
573
|
"""
|
|
528
574
|
Check whether batch size of two inputs are the same
|
|
@@ -533,7 +579,7 @@ def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
|
|
|
533
579
|
f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.")
|
|
534
580
|
|
|
535
581
|
|
|
536
|
-
@
|
|
582
|
+
@_primexpr
|
|
537
583
|
def _get_output_shape(batch_size, x1_ret, x2_ret):
|
|
538
584
|
"""
|
|
539
585
|
Compute output shape for batch dot
|
|
@@ -725,6 +771,49 @@ def matmul(x1, x2, dtype=None):
|
|
|
725
771
|
return res
|
|
726
772
|
|
|
727
773
|
|
|
774
|
+
def mm(input, mat2):
|
|
775
|
+
r"""
|
|
776
|
+
Returns the matrix product of two arrays.
|
|
777
|
+
If `input` is a :math:`(n \times m)` Tensor, `mat2` is a
|
|
778
|
+
:math:`(m \times p)` Tensor, `out` will be a :math:`(n \times p)` Tensor.
|
|
779
|
+
|
|
780
|
+
Note:
|
|
781
|
+
This function cannot support broadcasting.
|
|
782
|
+
Refer to :func:`mindspore.ops.matmul` instead if you need a broadcastable function.
|
|
783
|
+
|
|
784
|
+
Args:
|
|
785
|
+
input (Tensor): The first matrix of matrix multiplication.
|
|
786
|
+
The last dimension of `input` must be the same size as the first dimension of `mat2`.
|
|
787
|
+
mat2 (Tensor): The second matrix of matrix multiplication.
|
|
788
|
+
The last dimension of `input` must be the same size as the first dimension of `mat2`.
|
|
789
|
+
|
|
790
|
+
Returns:
|
|
791
|
+
Tensor or scalar, the matrix product of the inputs.
|
|
792
|
+
|
|
793
|
+
Raises:
|
|
794
|
+
ValueError: If the last dimension of `input` is not the same size as the
|
|
795
|
+
second-to-last dimension of `mat2`.
|
|
796
|
+
ValueError: If `input` or `mat2` is not a matrix.
|
|
797
|
+
|
|
798
|
+
Supported Platforms:
|
|
799
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
800
|
+
|
|
801
|
+
Examples:
|
|
802
|
+
>>> import mindspore as ms
|
|
803
|
+
>>> import mindspore.ops as ops
|
|
804
|
+
>>> import numpy as np
|
|
805
|
+
>>> x1 = ms.Tensor(np.random.rand(2, 3))
|
|
806
|
+
>>> x2 = ms.Tensor(np.random.rand(3, 4))
|
|
807
|
+
>>> out = ops.mm(x1, x2)
|
|
808
|
+
>>> print(out.shape)
|
|
809
|
+
(2, 4)
|
|
810
|
+
"""
|
|
811
|
+
if input.ndim != 2 or mat2.ndim != 2:
|
|
812
|
+
raise ValueError(f"For mm, the input tensor must be a matrix, "
|
|
813
|
+
f"but got mat1.ndim:{input.ndim}, mat2.ndim:{mat2.ndim}")
|
|
814
|
+
return matmul(input, mat2)
|
|
815
|
+
|
|
816
|
+
|
|
728
817
|
def cummin(x, axis):
|
|
729
818
|
r"""
|
|
730
819
|
Returns a tuple (values,indices) where 'values' is the cumulative minimum value of input Tensor `x`
|
|
@@ -763,51 +852,3 @@ def cummin(x, axis):
|
|
|
763
852
|
[0 1 1 1 4 4]
|
|
764
853
|
"""
|
|
765
854
|
return cummin_(x, axis)
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
def resize_nearest_neighbor(input_x, size, align_corners=False):
|
|
769
|
-
r"""
|
|
770
|
-
Resizes the input tensor by using the nearest neighbor algorithm.
|
|
771
|
-
|
|
772
|
-
Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
|
|
773
|
-
neighbor algorithm selects the value of the nearest point and does not consider the
|
|
774
|
-
values of neighboring points at all, yielding a piecewise-constant interpolant.
|
|
775
|
-
|
|
776
|
-
Args:
|
|
777
|
-
input_x (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
|
|
778
|
-
size (Union[Tensor, tuple, list]): The target size. The dimension of size must be 2.
|
|
779
|
-
align_corners (bool): Whether the centers of the 4 corner pixels of the input
|
|
780
|
-
and output tensors are aligned. Default: False.
|
|
781
|
-
|
|
782
|
-
Returns:
|
|
783
|
-
Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
|
|
784
|
-
The data type is the same as the `input_x`.
|
|
785
|
-
|
|
786
|
-
Raises:
|
|
787
|
-
TypeError: If `input_x` is not a Tensor.
|
|
788
|
-
TypeError: If `size` is neither tuple nor list.
|
|
789
|
-
TypeError: If `align_corners` is not a bool.
|
|
790
|
-
ValueError: If length of `size` is not equal to 2.
|
|
791
|
-
|
|
792
|
-
Supported Platforms:
|
|
793
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
794
|
-
|
|
795
|
-
Examples:
|
|
796
|
-
>>> input_tensor = Tensor(np.array([[[[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]]]), mindspore.float32)
|
|
797
|
-
>>> output = ops.ResizeNearestNeighbor(input_tensor, (2, 2))
|
|
798
|
-
>>> print(output)
|
|
799
|
-
[[[[-0.1 0.3]
|
|
800
|
-
[ 0.4 0.5]]]]
|
|
801
|
-
"""
|
|
802
|
-
if size is None:
|
|
803
|
-
raise ValueError(f'For ResizeNearestNeighbor, size could not be None.')
|
|
804
|
-
if isinstance(size, (tuple, list)):
|
|
805
|
-
resize = P.ResizeNearestNeighbor(size, align_corners)
|
|
806
|
-
return resize(input_x)
|
|
807
|
-
if is_const(size):
|
|
808
|
-
size = size.asnumpy()
|
|
809
|
-
resize = P.ResizeNearestNeighbor(size, align_corners)
|
|
810
|
-
return resize(input_x)
|
|
811
|
-
|
|
812
|
-
resize = DynamicResizeNearestNeighbor(align_corners)
|
|
813
|
-
return resize(input_x, size)
|