mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore has been flagged as potentially problematic; see the package's registry page for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -13,699 +13,10 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""math Operations."""
|
|
16
|
-
|
|
17
|
-
from mindspore.common import dtype as mstype
|
|
18
|
-
from mindspore import _checkparam as validator
|
|
19
|
-
from mindspore.ops.primitive import constexpr, _primexpr
|
|
16
|
+
import mindspore.ops as ops
|
|
20
17
|
from mindspore.ops import functional as F
|
|
21
18
|
from mindspore.ops.function.math_func import cummin as cummin_
|
|
22
|
-
from mindspore.ops import
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
@_primexpr
|
|
26
|
-
def _check_validate_axis(axis, name):
|
|
27
|
-
def _check(axis):
|
|
28
|
-
if isinstance(axis, (tuple, list)):
|
|
29
|
-
for idx, item in enumerate(axis):
|
|
30
|
-
validator.check_value_type("axis[%d]" % idx, item, [int], name)
|
|
31
|
-
_check(axis)
|
|
32
|
-
axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
|
|
33
|
-
return axis
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@constexpr
|
|
37
|
-
def _check_validate_keepdims(keep_dims, name):
|
|
38
|
-
keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
|
|
39
|
-
return keep_dims
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
@constexpr
|
|
43
|
-
def is_const(x):
|
|
44
|
-
return x is not None
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
|
|
48
|
-
r"""
|
|
49
|
-
Count number of nonzero elements across axis of input tensor.
|
|
50
|
-
|
|
51
|
-
Args:
|
|
52
|
-
x (Tensor): Input data is used to count non-zero numbers. With shape
|
|
53
|
-
:math:`(N,*)` where :math:`*` means, any number of additional dimensions.
|
|
54
|
-
axis (Union[int, tuple(int), list(int)], optional): The dimensions to reduce.
|
|
55
|
-
Default: (), reduce all dimensions.
|
|
56
|
-
keep_dims (bool, optional): Whether to maintain dimensions specified by `axis`.
|
|
57
|
-
If true, keep these reduced dimensions and the length is 1.
|
|
58
|
-
If false, don't keep these dimensions. Default: False.
|
|
59
|
-
dtype (Union[Number, mindspore.bool\_], optional): The data type of the output tensor.
|
|
60
|
-
Default: mindspore.int32.
|
|
61
|
-
|
|
62
|
-
Returns:
|
|
63
|
-
Tensor, number of nonzero element across axis specified by `axis`.
|
|
64
|
-
The data type is specified by `dtype`.
|
|
65
|
-
|
|
66
|
-
Raises:
|
|
67
|
-
TypeError: If `axis` is not int, tuple or list.
|
|
68
|
-
ValueError: If any value in `axis` is not in range [-x.ndim, x.ndim).
|
|
69
|
-
|
|
70
|
-
Supported Platforms:
|
|
71
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
72
|
-
|
|
73
|
-
Examples:
|
|
74
|
-
>>> from mindspore import Tensor, ops
|
|
75
|
-
>>> import numpy as np
|
|
76
|
-
>>> # case 1: each value specified.
|
|
77
|
-
>>> x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
|
|
78
|
-
>>> nonzero_num = ops.count_nonzero(x=x, axis=[0, 1], keep_dims=True, dtype=mindspore.int32)
|
|
79
|
-
>>> print(nonzero_num)
|
|
80
|
-
[[3]]
|
|
81
|
-
>>> # case 2: all value is default.
|
|
82
|
-
>>> nonzero_num = ops.count_nonzero(x=x)
|
|
83
|
-
>>> print(nonzero_num)
|
|
84
|
-
3
|
|
85
|
-
>>> # case 3: axis value was specified 0.
|
|
86
|
-
>>> nonzero_num = ops.count_nonzero(x=x, axis=[0,])
|
|
87
|
-
>>> print(nonzero_num)
|
|
88
|
-
[1 2 0]
|
|
89
|
-
>>> # case 4: axis value was specified 1.
|
|
90
|
-
>>> nonzero_num = ops.count_nonzero(x=x, axis=[1,])
|
|
91
|
-
>>> print(nonzero_num)
|
|
92
|
-
[1 2]
|
|
93
|
-
>>> # case 5: keep_dims value was specified.
|
|
94
|
-
>>> nonzero_num = ops.count_nonzero(x=x, keep_dims=True)
|
|
95
|
-
>>> print(nonzero_num)
|
|
96
|
-
[[3]]
|
|
97
|
-
>>> # case 6: keep_dims and axis value was specified.
|
|
98
|
-
>>> nonzero_num = ops.count_nonzero(x=x, axis=[0,], keep_dims=True)
|
|
99
|
-
>>> print(nonzero_num)
|
|
100
|
-
[[1 2 0]]
|
|
101
|
-
"""
|
|
102
|
-
|
|
103
|
-
const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
|
|
104
|
-
axis = _check_validate_axis(axis, "count_nonzero")
|
|
105
|
-
keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
|
|
106
|
-
const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype')
|
|
107
|
-
|
|
108
|
-
not_equal = P.NotEqual()
|
|
109
|
-
cast = P.Cast()
|
|
110
|
-
reduce_sum = P.ReduceSum(keep_dims)
|
|
111
|
-
zeros = P.Zeros()
|
|
112
|
-
tensor_0 = zeros(x.shape, x.dtype)
|
|
113
|
-
nonzero_bool = not_equal(x, tensor_0)
|
|
114
|
-
# ReduceSum only support float16 or float32 tensor.
|
|
115
|
-
nonzero_val = cast(nonzero_bool, mstype.float32)
|
|
116
|
-
nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)
|
|
117
|
-
|
|
118
|
-
return nonzero_num
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
@_primexpr
|
|
122
|
-
def _int_to_tuple_conv(axes):
|
|
123
|
-
"""
|
|
124
|
-
Converts ints to tuples in input axes, expected by most validation checks.
|
|
125
|
-
"""
|
|
126
|
-
for x in [0, 1]:
|
|
127
|
-
if isinstance(axes[x], int):
|
|
128
|
-
axes[x] = (axes[x],)
|
|
129
|
-
return axes
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
@_primexpr
|
|
133
|
-
def _check_axes(axes, prim_name=None):
|
|
134
|
-
"""
|
|
135
|
-
Check for validity and type of axes passed to function.
|
|
136
|
-
"""
|
|
137
|
-
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
138
|
-
validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
|
|
139
|
-
if not isinstance(axes, int):
|
|
140
|
-
axes = list(axes) # to avoid immutability issues
|
|
141
|
-
if len(axes) != 2:
|
|
142
|
-
raise ValueError(f"{msg_prefix} dimension of 'axes' must be 2, but got 'axes': {axes}.")
|
|
143
|
-
axes = _int_to_tuple_conv(axes) # convert before length checks
|
|
144
|
-
if len(axes[0]) != len(axes[1]):
|
|
145
|
-
raise ValueError(f"{msg_prefix} first and second dim of 'axes' have to be the same size/length, "
|
|
146
|
-
f"but got 'axes': {axes}.")
|
|
147
|
-
if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
|
|
148
|
-
raise ValueError(f"{msg_prefix} 'axes' cannot have duplicating values, but got {axes}.")
|
|
149
|
-
return axes
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
@constexpr
|
|
153
|
-
def _typecheck_input(x1_type, x2_type, prim_name=None):
|
|
154
|
-
"""
|
|
155
|
-
Check input tensor types to be valid and confirm they are the same type.
|
|
156
|
-
"""
|
|
157
|
-
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
158
|
-
const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
|
|
159
|
-
const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
|
|
160
|
-
if x1_type != x2_type:
|
|
161
|
-
raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} "
|
|
162
|
-
f"and x2_type: {x2_type}.")
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
@_primexpr
|
|
166
|
-
def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
|
|
167
|
-
"""
|
|
168
|
-
Convert from single int axes to 2d tuple if required
|
|
169
|
-
"""
|
|
170
|
-
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
171
|
-
|
|
172
|
-
def _check_lt_zero(axes):
|
|
173
|
-
if axes < 0:
|
|
174
|
-
raise ValueError(f"{msg_prefix} 'axes' must be at least 0, but got {axes}.")
|
|
175
|
-
|
|
176
|
-
def _check_len(axes, x1_shape, x2_shape):
|
|
177
|
-
if axes > len(x1_shape) or axes > len(x2_shape):
|
|
178
|
-
raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
|
|
179
|
-
f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
if isinstance(axes, int):
|
|
183
|
-
_check_lt_zero(axes)
|
|
184
|
-
if axes == 0:
|
|
185
|
-
# outer product, no input validation required
|
|
186
|
-
return [], []
|
|
187
|
-
_check_len(axes, x1_shape, x2_shape)
|
|
188
|
-
x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
|
|
189
|
-
x2_ind = tuple(range(len(x2_shape))[:axes])
|
|
190
|
-
axes = tuple((x1_ind, x2_ind))
|
|
191
|
-
axes = _int_to_tuple_conv(axes)
|
|
192
|
-
return axes
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
@_primexpr
|
|
196
|
-
def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
|
|
197
|
-
"""
|
|
198
|
-
Checks for axes having the correct length according to input, for any value in axis
|
|
199
|
-
being out of range with given shape and also checking for compatible axes values
|
|
200
|
-
with given inputs.
|
|
201
|
-
"""
|
|
202
|
-
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
203
|
-
|
|
204
|
-
def _check_len(axes_len, shape_dim_len, x_axes):
|
|
205
|
-
if axes_len > shape_dim_len:
|
|
206
|
-
raise ValueError(f"{msg_prefix} length of element {x_axes} in 'axes' must be less than or equal to "
|
|
207
|
-
f"{shape_dim_len}, but got {axes_len}.")
|
|
208
|
-
|
|
209
|
-
def _check_value(x_axes, min_val, max_val):
|
|
210
|
-
for _, x_value in enumerate(x_axes):
|
|
211
|
-
if x_value > max_val or x_value < min_val:
|
|
212
|
-
raise ValueError(f"{msg_prefix} value in 'axes' must be in range: [{min_val}, {max_val}], "
|
|
213
|
-
f"but got {x_value}.")
|
|
214
|
-
|
|
215
|
-
shapes = [x1_shape, x2_shape]
|
|
216
|
-
|
|
217
|
-
# axis length check
|
|
218
|
-
for ix_input, x_axes in enumerate(axes):
|
|
219
|
-
axes_len = len(x_axes)
|
|
220
|
-
shape_dim_len = len(shapes[ix_input])
|
|
221
|
-
_check_len(axes_len, shape_dim_len, x_axes)
|
|
222
|
-
|
|
223
|
-
# axis values range check
|
|
224
|
-
for ix_input, x_axes in enumerate(axes):
|
|
225
|
-
comp_shape = shapes[ix_input]
|
|
226
|
-
max_val = len(comp_shape) - 1
|
|
227
|
-
min_val = -1 * len(comp_shape)
|
|
228
|
-
_check_value(x_axes, min_val, max_val)
|
|
229
|
-
|
|
230
|
-
# check axis value with input shape - both ways for axis valid
|
|
231
|
-
invalid_a = False
|
|
232
|
-
invalid_b = False
|
|
233
|
-
for i in range(len(axes[0])): # sizes already validated
|
|
234
|
-
if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
|
|
235
|
-
invalid_a = True
|
|
236
|
-
if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0]) - 1 - i]]:
|
|
237
|
-
invalid_b = True
|
|
238
|
-
|
|
239
|
-
def _check(invalid_a, invalid_b, x1_shape, x2_shape, axes):
|
|
240
|
-
if invalid_a and invalid_b:
|
|
241
|
-
raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
|
|
242
|
-
f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
|
|
243
|
-
f"'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}, 'axes': {axes}.")
|
|
244
|
-
|
|
245
|
-
_check(invalid_a, invalid_b, x1_shape, x2_shape, axes)
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
@_primexpr
|
|
249
|
-
def _calc_new_shape(shape, axes, position=0):
|
|
250
|
-
"""
|
|
251
|
-
Calculate transpose and reshape parameters for input transformations,
|
|
252
|
-
'position' refers to whether tensor is first or second in the op.
|
|
253
|
-
"""
|
|
254
|
-
contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
|
|
255
|
-
prod_contraction = 1
|
|
256
|
-
for i in contraction_axes:
|
|
257
|
-
prod_contraction *= shape[i]
|
|
258
|
-
free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
|
|
259
|
-
free_dims = tuple(shape[i] if shape[i] is not None else -1 for i in free_axes)
|
|
260
|
-
prod_free = 1
|
|
261
|
-
for free_dim in free_dims:
|
|
262
|
-
prod_free *= free_dim
|
|
263
|
-
|
|
264
|
-
transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
|
|
265
|
-
new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
|
|
266
|
-
return new_shape, transpose_perm, free_dims
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
def tensor_dot(x1, x2, axes):
|
|
270
|
-
"""
|
|
271
|
-
Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`.
|
|
272
|
-
|
|
273
|
-
Contraction allows for the summation of products of elements of `a` and `b` on specified axes.
|
|
274
|
-
The same number of axes must be specified for both x1 and x2, and values must be within range
|
|
275
|
-
of number of dims of both `a` and `b`.
|
|
276
|
-
|
|
277
|
-
Selected dims in both inputs must also match.
|
|
278
|
-
|
|
279
|
-
axes = 0 leads to outer product.
|
|
280
|
-
axes = 1 leads to normal matrix multiplication when inputs both 2D.
|
|
281
|
-
axes = 1 is the same as axes = ((1,),(0,)) where both `a` and `b` are 2D.
|
|
282
|
-
axes = 2 is the same as axes = ((1,2),(0,1)) where both `a` and `b` are 3D.
|
|
283
|
-
|
|
284
|
-
Args:
|
|
285
|
-
x1 (Tensor): First tensor in tensor_dot with datatype float16 or float32
|
|
286
|
-
x2 (Tensor): Second tensor in tensor_dot with datatype float16 or float32
|
|
287
|
-
axes (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or
|
|
288
|
-
tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed,
|
|
289
|
-
automatically picks up last N dims from `a` input shape and first N dims from `b` input shape in order
|
|
290
|
-
as axes for each respectively.
|
|
291
|
-
|
|
292
|
-
Returns:
|
|
293
|
-
Tensor, the shape of the output tensor is :math:`(N + M)`. Where :math:`N` and :math:`M` are the free axes not
|
|
294
|
-
contracted in both inputs
|
|
295
|
-
|
|
296
|
-
Raises:
|
|
297
|
-
TypeError: If `x1` or `x2` is not a Tensor.
|
|
298
|
-
TypeError: If `axes` is not one of the following: int, tuple, list.
|
|
299
|
-
|
|
300
|
-
Supported Platforms:
|
|
301
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
302
|
-
|
|
303
|
-
Examples:
|
|
304
|
-
>>> from mindspore import Tensor, ops
|
|
305
|
-
>>> import mindspore
|
|
306
|
-
>>> import numpy as np
|
|
307
|
-
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
|
|
308
|
-
>>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
|
|
309
|
-
>>> output = ops.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
|
|
310
|
-
>>> print(output)
|
|
311
|
-
[[2. 2. 2]
|
|
312
|
-
[2. 2. 2]
|
|
313
|
-
[2. 2. 2]]
|
|
314
|
-
"""
|
|
315
|
-
shape_op = P.Shape()
|
|
316
|
-
reshape_op = P.Reshape()
|
|
317
|
-
transpose_op = P.Transpose()
|
|
318
|
-
matmul_op = P.MatMul(False, False)
|
|
319
|
-
# input validity checks
|
|
320
|
-
x1_shape = shape_op(x1)
|
|
321
|
-
x2_shape = shape_op(x2)
|
|
322
|
-
axes = _check_axes(axes, 'tensor_dot')
|
|
323
|
-
# input compatibility check & axes format update
|
|
324
|
-
axes = _axes_int_check(x1_shape, x2_shape, axes, 'tensor_dot')
|
|
325
|
-
_validate_axes(x1_shape, x2_shape, axes, 'tensor_dot')
|
|
326
|
-
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
|
|
327
|
-
x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
|
|
328
|
-
output_shape = x1_ret + x2_ret # combine free axes from both inputs
|
|
329
|
-
# run tensor_dot op
|
|
330
|
-
x1_transposed = transpose_op(x1, x1_transpose_fwd)
|
|
331
|
-
x2_transposed = transpose_op(x2, x2_transpose_fwd)
|
|
332
|
-
x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)
|
|
333
|
-
x2_reshaped = reshape_op(x2_transposed, x2_reshape_fwd)
|
|
334
|
-
mul_result = matmul_op(x1_reshaped, x2_reshaped)
|
|
335
|
-
final_result = reshape_op(mul_result, output_shape)
|
|
336
|
-
return final_result
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
@_primexpr
|
|
340
|
-
def _check_invalid_input(x1_shape, x2_shape, prim_name=None):
|
|
341
|
-
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
342
|
-
if len(x1_shape) < 2 or len(x2_shape) < 2:
|
|
343
|
-
raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2',"
|
|
344
|
-
f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
@constexpr
|
|
348
|
-
def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
|
|
349
|
-
"""
|
|
350
|
-
Check input tensor types to be valid and confirm they are the same type for dot and batch dot ops.
|
|
351
|
-
"""
|
|
352
|
-
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
353
|
-
const_utils.check_type_valid(x1_type, [mstype.float16, mstype.float32], 'x1')
|
|
354
|
-
const_utils.check_type_valid(x2_type, [mstype.float16, mstype.float32], 'x2')
|
|
355
|
-
if x1_type != x2_type:
|
|
356
|
-
raise TypeError(f"{msg_prefix} inputs must be the same type, but got "
|
|
357
|
-
f"x1_type: {x1_type} and x2_type: {x2_type}.")
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
@_primexpr
|
|
361
|
-
def _get_transpose_shape(x2_shape):
|
|
362
|
-
x2_shape_range = tuple(range(len(x2_shape)))
|
|
363
|
-
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
|
|
364
|
-
return x2_shape_transpose
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
def dot(input, other):
    """
    Computes the dot product between samples in two tensors.

    When both inputs are 2-D this is an ordinary matrix multiplication. Otherwise the last
    axis of `input` is contracted with the second-to-last axis of `other`, giving an output of
    shape ``input.shape[:-1] + other.shape[:-2] + other.shape[-1:]``.

    Args:
        input (Tensor): First tensor in Dot op with datatype float16 or float32.
            The rank must be greater than or equal to 2.
        other (Tensor): Second tensor in Dot op with datatype float16 or float32.
            The rank must be greater than or equal to 2.

    Returns:
        Tensor, dot product of input and other.

    Raises:
        TypeError: If type of input and other are not the same.
        TypeError: If dtype of input or other is not float16 or float32.
        ValueError: If rank of input or other less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> input = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
        >>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input, other)
        >>> print(output)
        [[[3. 3.]]
         [[3. 3.]]]
        >>> print(output.shape)
        (2, 1, 2)
        >>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input, other)
        >>> print(output)
        [[[[3. 3.]]
          [[3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 1, 2)
        >>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> other = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input, other)
        >>> print(output)
        [[[[3. 3.]
           [3. 3.]]
          [[3. 3.]
           [3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 2, 2)
        >>> input = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
        >>> other = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input, other)
        >>> print(output)
        [[[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]]
        >>> print(output.shape)
        (3, 2, 2, 1, 2)
    """
    get_shape = P.Shape()
    reshape = P.Reshape()
    transpose = P.Transpose()
    matmul = P.MatMul(False, False)

    lhs_shape = get_shape(input)
    rhs_shape = get_shape(other)
    _typecheck_input_dot(F.dtype(input), F.dtype(other), 'dot')
    _check_invalid_input(lhs_shape, rhs_shape, 'dot')

    # Both operands are matrices: a plain MatMul is enough.
    if len(lhs_shape) <= 2 and len(rhs_shape) <= 2:
        return matmul(input, other)

    # Higher rank: move `other`'s contraction axis (second-to-last) to the front and
    # flatten both operands to 2-D so that a single MatMul computes every product.
    rhs_perm = _get_transpose_shape(rhs_shape)
    rhs_moved = transpose(other, rhs_perm)
    lhs_2d = reshape(input, (-1, lhs_shape[-1]))
    rhs_2d = reshape(rhs_moved, (rhs_shape[-2], -1))
    product = matmul(lhs_2d, rhs_2d)

    # Restore the free axes; the leading dim is left as -1 for Reshape to infer.
    full_shape = lhs_shape[:-1] + rhs_shape[:-2] + rhs_shape[-1:]
    return reshape(product, (-1,) + full_shape[1:])
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
@_primexpr
def _get_batch_size(x1_shape, x2_shape, prim_name=None):
    """
    Return the leading (batch) dimension of each `batch_dot` input.

    Args:
        x1_shape (tuple[int]): Shape of the first input.
        x2_shape (tuple[int]): Shape of the second input.
        prim_name (str, optional): Operator name used to prefix the error message. Default: None.

    Returns:
        tuple: ``(x1_shape[0], x2_shape[0])``.

    Raises:
        ValueError: If either shape has fewer than 2 dimensions.
    """
    if min(len(x1_shape), len(x2_shape)) < 2:
        prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise ValueError(f"{prefix} inputs x1, x2 should have 'dimension >= 2', "
                         f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
    return x1_shape[0], x2_shape[0]
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
@constexpr
def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
    """
    Validate the dtypes of the two inputs of `batch_dot`.

    Each dtype is passed to `const_utils.check_type_valid` with float32 as the only
    allowed type; afterwards both dtypes must be identical.

    Args:
        x1_type: dtype of the first input.
        x2_type: dtype of the second input.
        prim_name (str, optional): Operator name used to prefix the error message. Default: None.

    Raises:
        TypeError: If the two dtypes differ.
    """
    for dtype, arg_name in ((x1_type, 'x1'), (x2_type, 'x2')):
        const_utils.check_type_valid(dtype, [mstype.float32], arg_name)
    if x1_type != x2_type:
        prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise TypeError(f"{prefix} inputs must be the same type, but got x1_type: {x1_type} and "
                        f"x2_type: {x2_type}.")
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
@_primexpr
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
    """
    Validate `axes` for `batch_dot` and normalize it to a list of two non-negative ints.

    Args:
        x1_shape (tuple[int]): Shape of the first input.
        x2_shape (tuple[int]): Shape of the second input.
        axes (Union[None, int, tuple(int), list(int)]): Contraction axes.
            ``None`` selects the last axis of `x1`, and the second-to-last axis of `x2`
            (its last axis when `x2` is 2-D). An int selects the same index for both inputs.
            Negative values are counted from the end of the respective shape.
        prim_name (str, optional): Operator name used to prefix error messages. Default: None.

    Returns:
        list[int]: A list ``[axis_for_x1, axis_for_x2]``. For tuple/list input this is always a
        new list, so a list passed by the caller is never modified.

    Raises:
        ValueError: If a tuple/list `axes` contains 0 or does not have length 2.
        ValueError: If an int `axes` is 0 or exceeds the rank of either input.
        ValueError: If a normalized axis is negative, refers to the batch dim 0,
            or exceeds the rank of its input.
        ValueError: If `axes` is not None, an int, a tuple or a list.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"

    def _check_1(axes):
        if 0 in axes:
            raise ValueError(f"{msg_prefix} 'axes' cannot contain 0, but got axes: {axes}.")
        if len(axes) != 2:
            raise ValueError(f"{msg_prefix} length of 'axes' must be equal to 2, but got {len(axes)}.")

    def _check_2(axes, x1_shape, x2_shape):
        if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
            raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
                             f"and axes[1] must be less than or equal to len(x2_shape). "
                             f"But got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")

    def _check_3(axes, x1_shape, x2_shape):
        if axes == 0:
            raise ValueError(f"{msg_prefix} 'axes' should not be equal to 0, but got {axes}.")

        if axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
                             f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")

    def _check_not_batch_axis(axes):
        # A negative axis can normalize to 0, which bypasses the explicit "0 in axes" check above.
        if 0 in axes:
            raise ValueError(f"{msg_prefix} 'axes' cannot refer to the batch dimension 0 after "
                             f"normalizing negative values, but got axes: {axes}.")

    if axes is None:
        if len(x2_shape) == 2:
            axes = [len(x1_shape) - 1, len(x2_shape) - 1]
        else:
            axes = [len(x1_shape) - 1, len(x2_shape) - 2]

    if isinstance(axes, (list, tuple)):
        _check_1(axes)
        # Work on a copy: the normalization below (and batch_dot afterwards) writes into
        # the list, which must not leak into a list object owned by the caller.
        axes = list(axes)
        validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
        validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
        # Normalize negative axes
        if axes[0] < 0:
            axes[0] += len(x1_shape)
        if axes[1] < 0:
            axes[1] += len(x2_shape)
        validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
        validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
        _check_not_batch_axis(axes)
        _check_2(axes, x1_shape, x2_shape)
    elif isinstance(axes, int):
        _check_3(axes, x1_shape, x2_shape)
        if axes < 0:
            axes = [axes + len(x1_shape), axes + len(x2_shape)]
            # Both normalized axes must be validated, not just the first.
            validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
            validator.check_non_negative_int(axes[1], 'reversed axes', 'batch_dot')
            _check_not_batch_axis(axes)
        else:
            axes = [axes, axes]
    else:
        raise ValueError(f"{msg_prefix} type of 'axes' must be one of those: int, tuple(int), list(int), "
                         f"but got {type(axes).__name__}.")
    return axes
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
@_primexpr
def _calc_new_shape_batchdot(shape, axes, position=0):
    """
    Compute the transpose permutation and 3-D target shape for one `batch_dot` operand.

    Axis 0 (batch) always stays first. The contraction axis ``axes[position]`` is placed last
    for the first operand (``position == 0``) and directly after the batch axis otherwise.
    All remaining axes from 1 onward ("free" axes) are flattened into a single dimension.

    Args:
        shape (tuple[int]): Shape of the operand.
        axes (list[int]): Normalized contraction axes for both operands.
        position (int): 0 for the first operand, non-zero for the second. Default: 0.

    Returns:
        tuple: ``(new_shape, transpose_perm, free_dims)``.
            `new_shape` is ``(batch, prod_free, contraction)`` when `position` is 0,
            or ``(batch, contraction, prod_free)`` otherwise. `free_dims` lists the sizes
            of the free axes in order.
    """
    contraction = (axes[position],)
    free_axes = tuple(idx for idx in range(1, len(shape)) if idx not in contraction)
    free_dims = tuple(shape[idx] for idx in free_axes)

    contraction_size = 1
    for idx in contraction:
        contraction_size *= shape[idx]
    free_size = 1
    for dim in free_dims:
        free_size *= dim

    if position:
        transpose_perm = (0,) + contraction + free_axes
        new_shape = (shape[0], contraction_size, free_size)
    else:
        transpose_perm = (0,) + free_axes + contraction
        new_shape = (shape[0], free_size, contraction_size)
    return new_shape, transpose_perm, free_dims
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
@_primexpr
def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
    """
    Ensure the two `batch_dot` inputs share the same batch size.

    Args:
        x1_batch_size: Leading dimension of the first input.
        x2_batch_size: Leading dimension of the second input.
        prim_name (str, optional): Operator name used to prefix the error message. Default: None.

    Raises:
        ValueError: If the batch sizes differ.
    """
    if x1_batch_size != x2_batch_size:
        prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise ValueError(f"{prefix} inputs 'x1', 'x2' should have the same batch sizes, but got "
                         f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.")
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
@_primexpr
def _get_output_shape(batch_size, x1_ret, x2_ret):
    """
    Assemble the `batch_dot` output shape.

    Args:
        batch_size (int): Shared batch dimension.
        x1_ret (tuple[int]): Free dimensions of the first input.
        x2_ret (tuple[int]): Free dimensions of the second input.

    Returns:
        tuple[int]: ``(batch_size, *x1_ret, *x2_ret)``.
    """
    return (batch_size,) + x1_ret + x2_ret
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
def batch_dot(x1, x2, axes=None):
    """
    Computation of batch dot product between samples in two tensors containing batch dims.

    .. math::
        output = x1[batch, :] * x2[batch, :]

    Args:
        x1 (Tensor): First tensor in Batch Dot op with datatype float32 and the rank of `x1` must be greater
          than or equal to 2.
        x2 (Tensor): Second tensor in Batch Dot op with datatype float32. The datatype of `x2` should
          be same as `x1` and the rank of `x2` must be greater than or equal to 2.
        axes (Union[int, tuple(int), list(int)]): Contraction axes. Either a single int, or a tuple/list
          of length 2 giving the axis for `x1` and `x2` respectively. A single int `N` selects axis `N`
          of both inputs. Negative values count from the end of each input's shape. The batch axis 0
          cannot be used. If None, the last axis of `x1` is contracted with the second-to-last axis of
          `x2` (its last axis when `x2` is 2-D). Default: None.

    Returns:
        Tensor, batch dot product of `x1` and `x2`. For example, the Shape of output
        for input `x1` shapes (batch, d1, axes, d2) and `x2` shapes (batch, d3, axes, d4) is (batch, d1, d2, d3, d4),
        where d1 and d2 means any number. When both inputs are 2-D, the output shape is (batch, 1).

    Raises:
        TypeError: If type of x1 and x2 are not the same.
        TypeError: If dtype of x1 or x2 is not float32.
        ValueError: If rank of x1 or x2 less than 2.
        ValueError: If batch dim used in axes.
        ValueError: If `axes` is a tuple or list whose length is not 2.
        ValueError: If axes is not one of those: None, int, (int, int).
        ValueError: If axes reversed from negative int is too low for dimensions of input arrays.
        ValueError: If axes value is too high for dimensions of input arrays.
        ValueError: If batch size of x1 and x2 are not the same.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (-1, -2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[[3. 3.]
          [3. 3.]]
         [[3. 3.]
          [3. 3.]]]
        >>> x1 = Tensor(np.ones(shape=[2, 2]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (1, 2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]]
        >>> print(output.shape)
        (2, 3)
        >>> x1 = Tensor(np.ones(shape=[6, 2, 3, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[6, 5, 4, 8]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (6, 2, 3, 5, 8)
        >>> x1 = Tensor(np.ones(shape=[2, 2, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 5, 4, 5]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (2, 2, 5, 5)
    """

    transpose_op = P.Transpose()
    batch_matmul_op = P.BatchMatMul()
    squeeze_one_op = P.Squeeze(1)
    squeeze_minus_one_op = P.Squeeze(-1)
    # input validity checks
    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)
    x1_dim_num = len(x1_shape)
    x2_dim_num = len(x2_shape)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)

    x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape, 'batch_dot')

    _typecheck_input_batch_dot(x1_type, x2_type, 'batch_dot')
    _check_batch_size(x1_batch_size, x2_batch_size, 'batch_dot')
    axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes, 'batch_dot')

    # 2-D inputs are lifted to 3-D so BatchMatMul can be used: x1 gains a size-1 axis at
    # position 1 (shifting its contraction axis by one) and x2 gains one at position 2.
    if x1_dim_num == 2:
        x1 = F.expand_dims(x1, 1)
        # Build a new list instead of mutating the one returned by the helper, which may
        # be the very list object the caller passed in as `axes`.
        axes = [axes[0] + 1, axes[1]]
    if x2_dim_num == 2:
        x2 = F.expand_dims(x2, 2)

    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)

    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
    output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)

    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
    x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)

    # Batch matmul op part
    mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)

    final_result = F.reshape(mul_result, output_shape)

    # If an original 2-D input was expanded, remove the added axis again. When both inputs
    # were 2-D only x1's axis is squeezed, which leaves a (batch, 1) result.
    if x1_dim_num == 2:
        final_result = squeeze_one_op(final_result)
    elif x2_dim_num == 2:
        final_result = squeeze_minus_one_op(final_result)

    return final_result
|
|
19
|
+
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
709
20
|
|
|
710
21
|
|
|
711
22
|
def matmul(x1, x2, dtype=None):
|
|
@@ -808,10 +119,9 @@ def mm(input, mat2):
|
|
|
808
119
|
>>> print(out.shape)
|
|
809
120
|
(2, 4)
|
|
810
121
|
"""
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
return matmul(input, mat2)
|
|
122
|
+
_matmul = _get_cache_prim(ops.MatMul)()
|
|
123
|
+
out = _matmul(input, mat2)
|
|
124
|
+
return out
|
|
815
125
|
|
|
816
126
|
|
|
817
127
|
def cummin(x, axis):
|