mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -24,7 +24,8 @@ from mindspore.ops import operations as P
|
|
|
24
24
|
from mindspore.ops.composite import base
|
|
25
25
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
26
26
|
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
|
|
27
|
-
TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo
|
|
27
|
+
TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
|
|
28
|
+
SelectView, CopyWithSlice
|
|
28
29
|
from mindspore.common import dtype as mstype
|
|
29
30
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
30
31
|
from mindspore.common.initializer import Zero
|
|
@@ -32,6 +33,8 @@ from mindspore.common import Tensor, CSRTensor, COOTensor
|
|
|
32
33
|
from mindspore.common import mutable
|
|
33
34
|
from mindspore import ops
|
|
34
35
|
from mindspore.ops.primitive import _primexpr
|
|
36
|
+
from mindspore import _checkparam as validator
|
|
37
|
+
from mindspore.common._stub_tensor import _convert_stub
|
|
35
38
|
|
|
36
39
|
slice_get_item = SliceGetItem()
|
|
37
40
|
hyper_map = base.HyperMap()
|
|
@@ -42,6 +45,8 @@ is_parameter = IsParameter()
|
|
|
42
45
|
getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
|
|
43
46
|
setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
|
|
44
47
|
|
|
48
|
+
selevt_view = SelectView()
|
|
49
|
+
copy_with_slice = CopyWithSlice()
|
|
45
50
|
|
|
46
51
|
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
|
|
47
52
|
new_axis_mask=0, shrink_axis_mask=0):
|
|
@@ -65,16 +70,23 @@ class ValueTransferType(IntEnum):
|
|
|
65
70
|
kGatherND = 9
|
|
66
71
|
kScatterNdUpdate = 10
|
|
67
72
|
kReshape = 11
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
73
|
+
kSelectView = 12
|
|
74
|
+
kUnsqueeze = 13
|
|
75
|
+
kCopyView = 14
|
|
76
|
+
kScatterND = 15
|
|
77
|
+
kNumberToTensor = 16
|
|
78
|
+
kHandleSequenceValue = 17
|
|
79
|
+
kByPass = 18
|
|
80
|
+
kReSetItemByIndex = 19
|
|
81
|
+
kCopySlice = 20
|
|
82
|
+
kSetItemByBool = 21
|
|
83
|
+
kEmptyTensor = 22
|
|
84
|
+
kSetItemByEllipsis = 23
|
|
85
|
+
kFormatIndexTensor = 24
|
|
86
|
+
kGetitemByBoolTensor = 25
|
|
87
|
+
kSetitemByBoolTensor = 26
|
|
88
|
+
kJustReturn = 27
|
|
89
|
+
kRaiseIndexError = 28
|
|
78
90
|
|
|
79
91
|
|
|
80
92
|
def data_update(transfer_types, args, data, new_index, value=None):
|
|
@@ -82,11 +94,14 @@ def data_update(transfer_types, args, data, new_index, value=None):
|
|
|
82
94
|
We finally generate a new tensor when handling tensor getitem/setitem
|
|
83
95
|
by transfer data and value with index.
|
|
84
96
|
"""
|
|
97
|
+
origin_data = data
|
|
85
98
|
for transfer_type, arg in zip(transfer_types, args):
|
|
86
99
|
if transfer_type == ValueTransferType.kUnknown:
|
|
87
100
|
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
88
101
|
if transfer_type <= ValueTransferType.kScatterND:
|
|
89
|
-
data = data_update_by_ops(transfer_type, arg, data, new_index, value)
|
|
102
|
+
data = data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value)
|
|
103
|
+
if transfer_type == ValueTransferType.kJustReturn:
|
|
104
|
+
return _convert_stub(arg)
|
|
90
105
|
if transfer_type == ValueTransferType.kSetItemByBool:
|
|
91
106
|
return tensor_setitem_by_bool(data, new_index, value)
|
|
92
107
|
if transfer_type == ValueTransferType.kCopySlice:
|
|
@@ -98,13 +113,19 @@ def data_update(transfer_types, args, data, new_index, value=None):
|
|
|
98
113
|
return data
|
|
99
114
|
if transfer_type == ValueTransferType.kEmptyTensor:
|
|
100
115
|
return handle_empty_tensor(arg, data)
|
|
116
|
+
if transfer_type == ValueTransferType.kFormatIndexTensor:
|
|
117
|
+
new_index = format_index_tensor(new_index, arg)
|
|
118
|
+
if transfer_type == ValueTransferType.kGetitemByBoolTensor:
|
|
119
|
+
return F.gather_nd(data, new_index.nonzero())
|
|
120
|
+
if transfer_type == ValueTransferType.kSetitemByBoolTensor:
|
|
121
|
+
return handle_setitem_by_bool_tensor(data, new_index, value)
|
|
101
122
|
if transfer_type == ValueTransferType.kRaiseIndexError:
|
|
102
123
|
raise IndexError(
|
|
103
124
|
f'index {arg[0]} is out of bounds for dimension with size {arg[1]}')
|
|
104
125
|
return data
|
|
105
126
|
|
|
106
127
|
|
|
107
|
-
def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
|
|
128
|
+
def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=None):
|
|
108
129
|
"""
|
|
109
130
|
Generate a new tensor when handling tensor getitem/setitem
|
|
110
131
|
by ops.
|
|
@@ -125,14 +146,22 @@ def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
|
|
|
125
146
|
F.scatter_nd_update(data, new_index, value)
|
|
126
147
|
elif transfer_type == ValueTransferType.kSelect:
|
|
127
148
|
data = F.select(Tensor(new_index), value, data)
|
|
149
|
+
elif transfer_type == ValueTransferType.kSelectView:
|
|
150
|
+
data = selevt_view(data, arg[0], arg[1])
|
|
151
|
+
elif transfer_type == ValueTransferType.kCopyView:
|
|
152
|
+
value = _broadcast(F.shape(data), F.cast(value, F.dtype(data)))
|
|
153
|
+
data = copy_with_slice(data, value)
|
|
154
|
+
return origin_data
|
|
128
155
|
elif transfer_type == ValueTransferType.kReshape:
|
|
129
156
|
data = F.reshape(data, arg)
|
|
130
157
|
elif transfer_type == ValueTransferType.kGather:
|
|
131
158
|
data = F.gather(data, new_index, 0)
|
|
132
159
|
elif transfer_type == ValueTransferType.kExpandDims:
|
|
133
160
|
data = F.expand_dims(data, 0)
|
|
161
|
+
elif transfer_type == ValueTransferType.kUnsqueeze:
|
|
162
|
+
data = F.unsqueeze(data, arg)
|
|
134
163
|
elif transfer_type == ValueTransferType.kStrideSlice:
|
|
135
|
-
data =
|
|
164
|
+
data = strided_slice(data, arg[0], arg[1], arg[2])
|
|
136
165
|
else:
|
|
137
166
|
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
138
167
|
return data
|
|
@@ -144,7 +173,7 @@ def value_update(transfer_types, args, data, value):
|
|
|
144
173
|
if transfer_type == ValueTransferType.kByPass:
|
|
145
174
|
continue
|
|
146
175
|
if transfer_type == ValueTransferType.kNumberToTensor:
|
|
147
|
-
value = F.
|
|
176
|
+
value = F.cast(value, F.dtype(data))
|
|
148
177
|
elif transfer_type == ValueTransferType.kHandleSequenceValue:
|
|
149
178
|
op_type, index = arg
|
|
150
179
|
if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
|
|
@@ -182,7 +211,10 @@ def _tensor_setitem(self, index, value):
|
|
|
182
211
|
data_update_types = setitem_info[3]
|
|
183
212
|
data_update_args = setitem_info[4]
|
|
184
213
|
value = value_update(v_transfer_types, v_transfer_args, self, value)
|
|
185
|
-
|
|
214
|
+
output = data_update(data_update_types, data_update_args, self, new_index, value)
|
|
215
|
+
if new_index == "view":
|
|
216
|
+
return (self,)
|
|
217
|
+
return output
|
|
186
218
|
|
|
187
219
|
|
|
188
220
|
tensor_operator_registry.register("__getitem__", _tensor_getitem)
|
|
@@ -273,17 +305,27 @@ def _scalar_to_tensor(input_x):
|
|
|
273
305
|
return ops.add(input_x, mutable(Tensor(0)))
|
|
274
306
|
|
|
275
307
|
|
|
308
|
+
@_primexpr
|
|
309
|
+
def _check_scalar_tensor_args(args):
|
|
310
|
+
"""For the item, check that the index of the scalar tensor is set."""
|
|
311
|
+
if args not in ((None,), ()):
|
|
312
|
+
const_utils.raise_value_error("For item, the index of scalar Tensor should not be set.")
|
|
313
|
+
|
|
314
|
+
|
|
276
315
|
def tensor_item(data, *args):
|
|
277
316
|
"""Tensor getitem by index whose dtype is int or tuple with int."""
|
|
278
317
|
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
318
|
+
if data.ndim == 0:
|
|
319
|
+
_check_scalar_tensor_args(args)
|
|
320
|
+
return data.asnumpy().item()
|
|
279
321
|
if len(args) == 1 and isinstance(args[0], tuple):
|
|
280
322
|
args = args[0]
|
|
281
323
|
|
|
282
324
|
args_types = hyper_map(F.typeof, args)
|
|
283
325
|
if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
|
|
284
326
|
if data.shape == (1,):
|
|
285
|
-
return data
|
|
286
|
-
const_utils.raise_value_error("Can only convert an array of size 1 to a
|
|
327
|
+
return data.asnumpy().item()
|
|
328
|
+
const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
|
|
287
329
|
|
|
288
330
|
if not const_utils.judge_indexes_types(args_types, mstype.int64):
|
|
289
331
|
const_utils.raise_type_error("The index object cannot be interpreted as an integer")
|
|
@@ -342,7 +384,8 @@ def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
|
|
|
342
384
|
exp_msg = const_utils.gen_exception_msg(
|
|
343
385
|
"Tuple index len({}) is not same to tensor dimension({})", len(tuple_index), data.ndim)
|
|
344
386
|
const_utils.raise_index_error(exp_msg)
|
|
345
|
-
|
|
387
|
+
nubmer_value = F.cast(nubmer_value, F.dtype(data))
|
|
388
|
+
return tensor_itemset_by_tuple_with_tensor(data, tuple_index, nubmer_value)
|
|
346
389
|
|
|
347
390
|
|
|
348
391
|
def _broadcast(broadcast_shape, x):
|
|
@@ -429,12 +472,39 @@ def handle_multi_dim_index_tensor(new_index, arg):
|
|
|
429
472
|
return new_index
|
|
430
473
|
|
|
431
474
|
|
|
475
|
+
def format_index_tensor(index, arg):
|
|
476
|
+
"""Format index tensor when tensor less than 0"""
|
|
477
|
+
format_indices, format_dims = arg
|
|
478
|
+
if isinstance(index, list):
|
|
479
|
+
for format_idx, format_dim in zip(format_indices, format_dims):
|
|
480
|
+
index_tensor = index[format_idx]
|
|
481
|
+
index[format_idx] = F.select(index_tensor < 0, index_tensor + format_dim, index_tensor)
|
|
482
|
+
return index
|
|
483
|
+
index = Tensor(index)
|
|
484
|
+
return F.select(index < 0, index + format_dims, index)
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def handle_setitem_by_bool_tensor(data, index, value):
|
|
488
|
+
"""Set a tensor item by a bool tensor with a tensor."""
|
|
489
|
+
value = F.cast(value, F.dtype(data))
|
|
490
|
+
indices = index.nonzero()
|
|
491
|
+
if indices.shape[0] == 0:
|
|
492
|
+
return data
|
|
493
|
+
value_shape = (indices.shape[0],) + data.shape[index.ndim:]
|
|
494
|
+
value = _broadcast(value_shape, value)
|
|
495
|
+
value = F.scatter_nd(indices, value, data.shape)
|
|
496
|
+
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
|
497
|
+
index = _broadcast(data.shape, index)
|
|
498
|
+
result = F.select(index, value, data)
|
|
499
|
+
return result
|
|
500
|
+
|
|
501
|
+
|
|
432
502
|
def _expand_data_dims(data, tuple_index):
|
|
433
503
|
"""expand the data's dim with 'None' and 'Boolean' in tuple_index"""
|
|
434
504
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
435
505
|
expand_positions, tuple_index_new = (), ()
|
|
436
506
|
for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
|
|
437
|
-
if isinstance(index_type, mstype.
|
|
507
|
+
if isinstance(index_type, mstype.NoneType):
|
|
438
508
|
tuple_index_new += (const_utils.make_empty_slice(),)
|
|
439
509
|
expand_positions += (i,)
|
|
440
510
|
elif isinstance(index_type, mstype.Bool):
|
|
@@ -471,29 +541,27 @@ def convert_variable_to_tensor_slice(slice_index):
|
|
|
471
541
|
return slice_index
|
|
472
542
|
|
|
473
543
|
|
|
544
|
+
class _TensorIndexGetitem(base.TensorIndexGetitem_):
|
|
545
|
+
"""
|
|
546
|
+
Getting item of Tensor.
|
|
547
|
+
|
|
548
|
+
Args:
|
|
549
|
+
data (Tensor): A tuple to be sliced.
|
|
550
|
+
index: Index of tensor.
|
|
551
|
+
|
|
552
|
+
Returns:
|
|
553
|
+
Type is the same as the element type of data.
|
|
554
|
+
"""
|
|
555
|
+
|
|
556
|
+
def __call__(self, *args):
|
|
557
|
+
pass
|
|
558
|
+
|
|
559
|
+
_tensor_index_getitem = _TensorIndexGetitem('tensor_index_getitem')
|
|
560
|
+
|
|
561
|
+
|
|
474
562
|
def tensor_index_by_slice(data, slice_index):
|
|
475
563
|
"""Tensor getitem by a slice."""
|
|
476
|
-
|
|
477
|
-
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
478
|
-
data_shape = F.shape(data)
|
|
479
|
-
slice_index = convert_variable_to_tensor_slice(slice_index)
|
|
480
|
-
|
|
481
|
-
is_dynamic = (F.is_sequence_value_unknown(data_shape)
|
|
482
|
-
or isinstance(slice_get_item(slice_index, "start"), Tensor)
|
|
483
|
-
or isinstance(slice_get_item(slice_index, "stop"), Tensor)
|
|
484
|
-
or isinstance(slice_get_item(slice_index, "step"), Tensor))
|
|
485
|
-
if is_dynamic:
|
|
486
|
-
begin_strides, end_strides, step_strides = get_stride_info_from_slice(data, slice_index)
|
|
487
|
-
else:
|
|
488
|
-
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(data_shape, slice_index)
|
|
489
|
-
begin_mask = 1 if slice_get_item(slice_index, "start") is None else 0
|
|
490
|
-
end_mask = 1 if slice_get_item(slice_index, "stop") is None else 0
|
|
491
|
-
for i in range(1, len(data_shape)):
|
|
492
|
-
begin_mask += 2 ** i
|
|
493
|
-
end_mask += 2 ** i
|
|
494
|
-
if begin_mask or end_mask:
|
|
495
|
-
return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, 0)
|
|
496
|
-
return F.strided_slice(data, begin_strides, end_strides, step_strides)
|
|
564
|
+
return _tensor_index_getitem(data, slice_index)
|
|
497
565
|
|
|
498
566
|
|
|
499
567
|
def get_stride_info_from_slice(data, slice_index):
|
|
@@ -531,9 +599,12 @@ def _tensor_index_by_bool(data, bool_value):
|
|
|
531
599
|
"""Tensor getitem by a single bool value"""
|
|
532
600
|
min_data_dim, max_data_dim = 0, 7
|
|
533
601
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
602
|
+
output = data
|
|
534
603
|
if bool_value:
|
|
535
|
-
|
|
536
|
-
|
|
604
|
+
output = F.expand_dims(data, 0)
|
|
605
|
+
elif not F.is_sequence_value_unknown(F.shape(data)):
|
|
606
|
+
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
|
|
607
|
+
return output
|
|
537
608
|
|
|
538
609
|
|
|
539
610
|
def get_stride_info_from_integer(tensor_int):
|
|
@@ -550,15 +621,14 @@ def get_stride_info_from_integer(tensor_int):
|
|
|
550
621
|
def _tensor_index_by_integer(data, int_index):
|
|
551
622
|
"""Tensor getitem by a single integer number"""
|
|
552
623
|
data_shape = F.shape(data)
|
|
553
|
-
if not data_shape:
|
|
554
|
-
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
555
|
-
if data.ndim < 1 or data.ndim > 8:
|
|
556
|
-
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
557
|
-
|
|
558
624
|
if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
|
|
559
625
|
tensor_index = _scalar_to_tensor(int_index)
|
|
560
626
|
begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
|
|
561
627
|
else:
|
|
628
|
+
if not data_shape:
|
|
629
|
+
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
630
|
+
if data.ndim < 1 or data.ndim > 8:
|
|
631
|
+
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
562
632
|
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
|
563
633
|
begin_strides, end_strides, step_strides = \
|
|
564
634
|
const_utils.get_stride_info_from_integer(data_shape, transformed_number)
|
|
@@ -570,7 +640,6 @@ def _tensor_index_by_integer(data, int_index):
|
|
|
570
640
|
end_mask += 2 ** i
|
|
571
641
|
return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
572
642
|
|
|
573
|
-
|
|
574
643
|
def _check_dim_shape_valid(data, tensor_index):
|
|
575
644
|
"""check dim and shape of tensor_index for tensor(bool) indexing"""
|
|
576
645
|
if data.ndim < tensor_index.ndim:
|
|
@@ -583,7 +652,8 @@ def _check_dim_shape_valid(data, tensor_index):
|
|
|
583
652
|
|
|
584
653
|
def tensor_index_by_bool_tensor(data, tensor_index):
|
|
585
654
|
"""Tensor getitem by a bool tensor"""
|
|
586
|
-
|
|
655
|
+
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
656
|
+
_check_dim_shape_valid(data, tensor_index)
|
|
587
657
|
tensor_index = tensor_index.nonzero()
|
|
588
658
|
return F.gather_nd(data, tensor_index)
|
|
589
659
|
|
|
@@ -591,7 +661,8 @@ def tensor_index_by_bool_tensor(data, tensor_index):
|
|
|
591
661
|
def tensor_index_by_tensor(data, tensor_index):
|
|
592
662
|
"""Tensor getitem by a single tensor"""
|
|
593
663
|
min_data_dim, max_data_dim = 0, 7
|
|
594
|
-
|
|
664
|
+
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
665
|
+
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
595
666
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
|
|
596
667
|
return F.gather(data, tensor_index, 0)
|
|
597
668
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
|
|
@@ -609,16 +680,22 @@ def tensor_index_by_list(data, list_index):
|
|
|
609
680
|
|
|
610
681
|
data_shape = F.shape(data)
|
|
611
682
|
indexes_types = hyper_map(toptypeof, list_index)
|
|
612
|
-
if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int))
|
|
683
|
+
if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)) \
|
|
684
|
+
and not F.is_sequence_value_unknown(list_index):
|
|
613
685
|
if not F.isconstant(data_shape[0]):
|
|
614
686
|
if all(isinstance(i, bool) for i in list_index):
|
|
615
|
-
|
|
616
|
-
|
|
687
|
+
if F.dyn_shape(data)[0] != len(list_index):
|
|
688
|
+
raise IndexError(
|
|
689
|
+
f'dimension is {F.dyn_shape(data)[0]} but corresponding boolean dimension is {len(list_index)}')
|
|
690
|
+
tensor_index = Tensor(list_index).nonzero()
|
|
691
|
+
return F.gather_nd(data, tensor_index)
|
|
617
692
|
tensor_index = const_utils.sequence_to_index(list_index, None)
|
|
618
693
|
else:
|
|
619
|
-
tensor_index = const_utils.sequence_to_index(
|
|
694
|
+
tensor_index = const_utils.sequence_to_index(
|
|
695
|
+
list_index, data_shape[0])
|
|
620
696
|
if tensor_index is False:
|
|
621
|
-
const_utils.raise_index_error(
|
|
697
|
+
const_utils.raise_index_error(
|
|
698
|
+
"When tensor is indexed by list, the list can't be empty.")
|
|
622
699
|
return F.gather(data, tensor_index, 0)
|
|
623
700
|
|
|
624
701
|
tuple_index_new = ()
|
|
@@ -637,23 +714,92 @@ def convert_tupleslice_to_tensor(tuple_index):
|
|
|
637
714
|
return tuple(new_tuple_index)
|
|
638
715
|
|
|
639
716
|
|
|
640
|
-
def
|
|
641
|
-
"""
|
|
642
|
-
if
|
|
643
|
-
|
|
717
|
+
def judge_tuple_index_dim_check_error(index_dim, data_dim):
|
|
718
|
+
"""raise IndexError when tuple_index's dim is invalid"""
|
|
719
|
+
if index_dim > data_dim:
|
|
720
|
+
raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
|
|
721
|
+
f"dim of index:{index_dim}, dim of data:{data_dim}")
|
|
644
722
|
|
|
645
|
-
tuple_index = convert_tupleslice_to_tensor(tuple_index)
|
|
646
|
-
op_name = const_utils.TENSOR_GETITEM
|
|
647
|
-
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
648
|
-
data, tuple_index = _expand_data_dims(data, tuple_index)
|
|
649
723
|
|
|
650
|
-
|
|
651
|
-
|
|
724
|
+
class _HandleEmptySlice(base.HandleEmptySlice_):
|
|
725
|
+
"""
|
|
726
|
+
Getting item of Tensor.
|
|
727
|
+
|
|
728
|
+
Args:
|
|
729
|
+
data (Tensor): A tuple to be sliced.
|
|
730
|
+
index: Index of tensor.
|
|
731
|
+
|
|
732
|
+
Returns:
|
|
733
|
+
Type is the same as the element type of data.
|
|
734
|
+
"""
|
|
735
|
+
|
|
736
|
+
def __init__(self, name):
|
|
737
|
+
"""Initialize _HandleEmptySlice."""
|
|
738
|
+
base.HandleEmptySlice_.__init__(self, name)
|
|
739
|
+
|
|
740
|
+
def __call__(self, *args):
|
|
741
|
+
pass
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
_handle_empty_slice = _HandleEmptySlice('handle_zero_tuple_index')
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def judge_tuple_index_dim(data, tuple_index):
|
|
748
|
+
"""Judge whether tuple_index's dim is valid"""
|
|
749
|
+
data_dim = data.ndim
|
|
750
|
+
index_dim = 0
|
|
751
|
+
for index in tuple_index:
|
|
752
|
+
if isinstance(toptypeof(index), mstype.TensorType) and index.dtype == mstype.bool_:
|
|
753
|
+
index_dim += index.ndim
|
|
754
|
+
elif not isinstance(toptypeof(index), (mstype.NoneType, mstype.Ellipsis_, mstype.Bool)):
|
|
755
|
+
index_dim += 1
|
|
756
|
+
judge_tuple_index_dim_check_error(index_dim, data_dim)
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
def judge_simple_tuple_index(data, tuple_index):
|
|
760
|
+
"""Judge whether tuple_index is simple index, which not rollback to cpu ops."""
|
|
761
|
+
op_name = const_utils.TENSOR_GETITEM
|
|
652
762
|
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
653
763
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
654
|
-
|
|
764
|
+
return F.isconstant(tuple_index) and contain_type == const_utils.ALL_BASIC \
|
|
765
|
+
and F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(F.rank(data))
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
def tensor_index_by_tuple(data, tuple_index):
|
|
769
|
+
"""Tensor getitem by tuple of various types with None"""
|
|
770
|
+
if not tuple_index:
|
|
771
|
+
return data
|
|
772
|
+
if judge_simple_tuple_index(data, tuple_index):
|
|
773
|
+
tuple_index = convert_tupleslice_to_tensor(tuple_index)
|
|
774
|
+
op_name = const_utils.TENSOR_GETITEM
|
|
775
|
+
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
776
|
+
min_data_dim, max_data_dim = 1, 8
|
|
777
|
+
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
655
778
|
return _tensor_getitem_by_tuple_slice(data, tuple_index)
|
|
656
|
-
|
|
779
|
+
|
|
780
|
+
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
781
|
+
judge_tuple_index_dim(data, tuple_index)
|
|
782
|
+
tuple_index, zero_index, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
783
|
+
for non_zero_shape in non_zero_shapes:
|
|
784
|
+
if F.reduce_min(non_zero_shape) == 0:
|
|
785
|
+
tuple_index = zero_index
|
|
786
|
+
break
|
|
787
|
+
if not F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(tuple_index):
|
|
788
|
+
_, stub_zero_dim_tensor = _handle_empty_slice(data, tuple_index)
|
|
789
|
+
if 0 in stub_zero_dim_tensor.shape:
|
|
790
|
+
return F.fill(data.dtype, stub_zero_dim_tensor.shape, 0)
|
|
791
|
+
has_tensor_index = False
|
|
792
|
+
for i in tuple_index:
|
|
793
|
+
if isinstance(i, Tensor):
|
|
794
|
+
has_tensor_index = True
|
|
795
|
+
break
|
|
796
|
+
empty_broadcast_data_shape = False
|
|
797
|
+
_broadcast_data_shape = _handle_scalar_tensor_index(data, tuple_index)
|
|
798
|
+
if has_tensor_index and isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
|
|
799
|
+
empty_broadcast_data_shape = True
|
|
800
|
+
if has_tensor_index and isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
|
|
801
|
+
empty_broadcast_data_shape = True
|
|
802
|
+
return _tensor_index_getitem(data, tuple_index, empty_broadcast_data_shape)
|
|
657
803
|
|
|
658
804
|
|
|
659
805
|
def get_slice_stride(slice_index, dim_size):
|
|
@@ -895,6 +1041,15 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
|
|
|
895
1041
|
return indices
|
|
896
1042
|
|
|
897
1043
|
|
|
1044
|
+
def parse_check_slice_index(index_out, dim_size):
|
|
1045
|
+
""" Parse and check slice index """
|
|
1046
|
+
has_false = False
|
|
1047
|
+
start, stop, step = const_utils.normalize_slice(index_out, dim_size)
|
|
1048
|
+
if F.isconstant(start) and F.isconstant(stop) and F.isconstant(step):
|
|
1049
|
+
has_false = const_utils.check_slice_empty(start, stop, step)
|
|
1050
|
+
return has_false
|
|
1051
|
+
|
|
1052
|
+
|
|
898
1053
|
def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
|
|
899
1054
|
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
|
|
900
1055
|
data_shape = F.shape(data)
|
|
@@ -925,8 +1080,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
|
|
|
925
1080
|
tuple_index_new += (tensor_index,)
|
|
926
1081
|
tensor_indexes.append(tensor_index)
|
|
927
1082
|
elif i in slice_positions:
|
|
928
|
-
|
|
929
|
-
if const_utils.check_slice_empty(start, stop, step):
|
|
1083
|
+
if parse_check_slice_index(index, dim_size):
|
|
930
1084
|
return False
|
|
931
1085
|
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
|
|
932
1086
|
slice_shapes += (len(slice_ele_list_index),)
|
|
@@ -962,7 +1116,7 @@ def sequence_to_tensor(value, dtype):
|
|
|
962
1116
|
|
|
963
1117
|
if value_elements_type == const_utils.ALL_TENSOR:
|
|
964
1118
|
value = F.stack(value).astype(dtype)
|
|
965
|
-
elif value_elements_type == const_utils.NO_TENSOR:
|
|
1119
|
+
elif value_elements_type == const_utils.NO_TENSOR and not F.is_sequence_value_unknown(value):
|
|
966
1120
|
value = const_utils.make_tensor(value, dtype)
|
|
967
1121
|
else:
|
|
968
1122
|
new_value = ()
|
|
@@ -984,7 +1138,7 @@ def _generate_updates_from_sequence(data, index, value, op_type):
|
|
|
984
1138
|
def _generate_updates_from_tensor(data, index, value, op_type):
|
|
985
1139
|
"""Generate an updates tensor from a tensor."""
|
|
986
1140
|
value = value.astype(data.dtype)
|
|
987
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1141
|
+
if F.is_sequence_value_unknown(F.shape(data)) or F.is_sequence_value_unknown(F.shape(index)):
|
|
988
1142
|
data_shape = F.dyn_shape(data)
|
|
989
1143
|
index_shape = F.dyn_shape(index)
|
|
990
1144
|
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
|
@@ -1025,13 +1179,49 @@ def tensor_setitem_by_number(self, index, value):
|
|
|
1025
1179
|
return tensor_setitem_by_number_with_sequence(self, index, value)
|
|
1026
1180
|
|
|
1027
1181
|
|
|
1182
|
+
def _tuple_index_transfer(broadcast_shape, final_shape, new_shape, x, all_empty_tensor):
|
|
1183
|
+
"""Transform tuple index tensor to the required."""
|
|
1184
|
+
if isinstance(broadcast_shape, Tensor):
|
|
1185
|
+
if not all_empty_tensor:
|
|
1186
|
+
x = F.broadcast_to(x, broadcast_shape)
|
|
1187
|
+
x = F.reshape(x, new_shape)
|
|
1188
|
+
x = F.broadcast_to(x, final_shape)
|
|
1189
|
+
return x
|
|
1190
|
+
item = _broadcast(broadcast_shape, x)
|
|
1191
|
+
return _broadcast(final_shape, F.reshape(item, new_shape))
|
|
1192
|
+
|
|
1193
|
+
|
|
1194
|
+
class _TensorIndexSetitem(base.TensorIndexSetitem_):
|
|
1195
|
+
"""
|
|
1196
|
+
Getting item of Tensor.
|
|
1197
|
+
|
|
1198
|
+
Args:
|
|
1199
|
+
data (Tensor): A tuple to be sliced.
|
|
1200
|
+
index: Index of tensor.
|
|
1201
|
+
|
|
1202
|
+
Returns:
|
|
1203
|
+
Type is the same as the element type of data.
|
|
1204
|
+
"""
|
|
1205
|
+
|
|
1206
|
+
def __call__(self, *args):
|
|
1207
|
+
pass
|
|
1208
|
+
|
|
1209
|
+
|
|
1210
|
+
_tensor_index_setitem = _TensorIndexSetitem('tensor_index_setitem')
|
|
1211
|
+
|
|
1212
|
+
|
|
1028
1213
|
def tensor_setitem_by_slice(self, index, value):
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
if
|
|
1033
|
-
return
|
|
1034
|
-
|
|
1214
|
+
"""Set a tensor item by slice."""
|
|
1215
|
+
indices, value_shape, start, stop, step, value = _tensor_index_setitem(
|
|
1216
|
+
self, index, value)
|
|
1217
|
+
if start == stop:
|
|
1218
|
+
return self
|
|
1219
|
+
value = F.broadcast_to(value, value_shape)
|
|
1220
|
+
if not const_utils.is_ascend() and step == 1:
|
|
1221
|
+
if isinstance(step, Tensor):
|
|
1222
|
+
return copy_slice(self, value, start, stop, step)
|
|
1223
|
+
return copy_slice(self, value, (start,), (stop,), (step,))
|
|
1224
|
+
return F.tensor_scatter_update(self, indices, value)
|
|
1035
1225
|
|
|
1036
1226
|
|
|
1037
1227
|
def tensor_setitem_by_ellipsis(self, index, value):
|
|
@@ -1049,8 +1239,6 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
|
|
1049
1239
|
updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
|
1050
1240
|
data_shape = F.shape(data)
|
|
1051
1241
|
first_val = data_shape[0]
|
|
1052
|
-
if not F.isconstant(first_val):
|
|
1053
|
-
first_val = -1
|
|
1054
1242
|
index = F.select(index < 0, index + first_val, index)
|
|
1055
1243
|
index = F.expand_dims(index, -1)
|
|
1056
1244
|
if F.rank(index) < 2:
|
|
@@ -1081,13 +1269,12 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|
|
1081
1269
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
|
1082
1270
|
|
|
1083
1271
|
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1084
|
-
|
|
1085
|
-
"Not supported to the dynamic shape tensor slice by using tensor of Boolean type")
|
|
1272
|
+
return tensor_setitem_by_tuple_with_tensor(data, (index,), value_tensor.astype(data.dtype))
|
|
1086
1273
|
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
|
1087
1274
|
|
|
1088
1275
|
|
|
1089
1276
|
def tensor_setitem_by_tensor_with_number(data, index, value):
|
|
1090
|
-
value = F.
|
|
1277
|
+
value = F.cast(value, F.dtype(data))
|
|
1091
1278
|
return tensor_setitem_by_tensor_with_tensor(data, index, value)
|
|
1092
1279
|
|
|
1093
1280
|
|
|
@@ -1118,13 +1305,13 @@ def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
|
|
|
1118
1305
|
|
|
1119
1306
|
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
|
1120
1307
|
"""Givens a scalar assign to tensor by slice"""
|
|
1121
|
-
value = F.
|
|
1308
|
+
value = F.cast(value, F.dtype(data))
|
|
1122
1309
|
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
|
1123
1310
|
|
|
1124
1311
|
|
|
1125
1312
|
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|
1126
1313
|
"""Assigns the tensor by tuple with number value."""
|
|
1127
|
-
value = F.
|
|
1314
|
+
value = F.cast(value, F.dtype(data))
|
|
1128
1315
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
|
1129
1316
|
|
|
1130
1317
|
|
|
@@ -1202,7 +1389,123 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
|
1202
1389
|
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
1203
1390
|
|
|
1204
1391
|
|
|
1392
|
+
class _PreSetitemByTuple(base.PreSetitemByTuple_):
|
|
1393
|
+
"""
|
|
1394
|
+
Getting item of Tensor.
|
|
1395
|
+
|
|
1396
|
+
Args:
|
|
1397
|
+
data (Tensor): A tuple to be sliced.
|
|
1398
|
+
index: Index of tensor.
|
|
1399
|
+
|
|
1400
|
+
Returns:
|
|
1401
|
+
Type is the same as the element type of data.
|
|
1402
|
+
"""
|
|
1403
|
+
|
|
1404
|
+
def __init__(self, name):
|
|
1405
|
+
"""Initialize _PreSetitemByTuple."""
|
|
1406
|
+
base.PreSetitemByTuple_.__init__(self, name)
|
|
1407
|
+
|
|
1408
|
+
def __call__(self, *args):
|
|
1409
|
+
pass
|
|
1410
|
+
|
|
1411
|
+
|
|
1412
|
+
_pre_setitem_by_tuple = _PreSetitemByTuple('pre_setitem_by_tuple')
|
|
1413
|
+
|
|
1414
|
+
|
|
1415
|
+
class _HandleBoolTensor(base.HandleBoolTensor_):
|
|
1416
|
+
"""
|
|
1417
|
+
Getting item of Tensor.
|
|
1418
|
+
|
|
1419
|
+
Args:
|
|
1420
|
+
data (Tensor): A tuple to be sliced.
|
|
1421
|
+
index: Index of tensor.
|
|
1422
|
+
|
|
1423
|
+
Returns:
|
|
1424
|
+
Type is the same as the element type of data.
|
|
1425
|
+
"""
|
|
1426
|
+
|
|
1427
|
+
def __init__(self, name):
|
|
1428
|
+
"""Initialize _HandleBoolTensor."""
|
|
1429
|
+
base.HandleBoolTensor_.__init__(self, name)
|
|
1430
|
+
|
|
1431
|
+
def __call__(self, *args):
|
|
1432
|
+
pass
|
|
1433
|
+
|
|
1434
|
+
|
|
1435
|
+
_handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
|
|
1436
|
+
|
|
1437
|
+
|
|
1438
|
+
class _HandleScalarTensorIndex(base.HandleScalarTensorIndex_):
|
|
1439
|
+
"""
|
|
1440
|
+
Getting item of Tensor.
|
|
1441
|
+
|
|
1442
|
+
Args:
|
|
1443
|
+
data (Tensor): A tuple to be sliced.
|
|
1444
|
+
index: Index of tensor.
|
|
1445
|
+
|
|
1446
|
+
Returns:
|
|
1447
|
+
Type is the same as the element type of data.
|
|
1448
|
+
"""
|
|
1449
|
+
|
|
1450
|
+
def __init__(self, name):
|
|
1451
|
+
"""Initialize _HandleBoolTensor."""
|
|
1452
|
+
base.HandleScalarTensorIndex_.__init__(self, name)
|
|
1453
|
+
|
|
1454
|
+
def __call__(self, *args):
|
|
1455
|
+
pass
|
|
1456
|
+
|
|
1457
|
+
|
|
1458
|
+
_handle_scalar_tensor_index = _HandleScalarTensorIndex('handle_scalar_tensor_index')
|
|
1459
|
+
|
|
1460
|
+
|
|
1205
1461
|
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
1462
|
+
"""Assigns the tensor by tuple with tensor value."""
|
|
1463
|
+
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1464
|
+
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1465
|
+
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1466
|
+
dim1_start, dim1_stop, _ = const_utils.normalize_slice(
|
|
1467
|
+
tuple_index[1], data.shape[1])
|
|
1468
|
+
if dim1_stop - dim1_start <= 0:
|
|
1469
|
+
return data
|
|
1470
|
+
dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
|
|
1471
|
+
start = (dim0_start, dim1_start)
|
|
1472
|
+
stop = (dim0_start + 1, dim1_stop)
|
|
1473
|
+
step = (1, 1)
|
|
1474
|
+
value_shape = (dim1_stop - dim1_start,) + \
|
|
1475
|
+
const_utils.tuple_slice(data.shape, 2, None)
|
|
1476
|
+
value = _broadcast(value_shape, value)
|
|
1477
|
+
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1478
|
+
tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
1479
|
+
|
|
1480
|
+
for non_zero_shape in non_zero_shapes:
|
|
1481
|
+
if F.reduce_min(non_zero_shape) == 0:
|
|
1482
|
+
return data
|
|
1483
|
+
value = value.astype(data.dtype)
|
|
1484
|
+
special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
|
|
1485
|
+
= _pre_setitem_by_tuple(data, tuple_index, value)
|
|
1486
|
+
if special_index == 0:
|
|
1487
|
+
return data
|
|
1488
|
+
value = F.reshape(value, new_value_shape)
|
|
1489
|
+
if not tuple_index or special_index == 1:
|
|
1490
|
+
data[True] = value
|
|
1491
|
+
return data
|
|
1492
|
+
|
|
1493
|
+
empty_broadcast_data_shape = False
|
|
1494
|
+
if isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
|
|
1495
|
+
empty_broadcast_data_shape = True
|
|
1496
|
+
if isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
|
|
1497
|
+
empty_broadcast_data_shape = True
|
|
1498
|
+
indices = _tensor_index_setitem(
|
|
1499
|
+
data, tuple_index, value, idx_advanced, empty_broadcast_data_shape)
|
|
1500
|
+
|
|
1501
|
+
updates = _generate_updates_from_tensor(
|
|
1502
|
+
data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
|
1503
|
+
if is_parameter(data):
|
|
1504
|
+
F.scatter_nd_update(data, indices, updates)
|
|
1505
|
+
return data
|
|
1506
|
+
return F.tensor_scatter_update(data, indices, updates)
|
|
1507
|
+
|
|
1508
|
+
def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
|
|
1206
1509
|
"""Assigns the tensor by tuple with tensor value."""
|
|
1207
1510
|
op_name = const_utils.TENSOR_SETITEM
|
|
1208
1511
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
@@ -1220,7 +1523,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1220
1523
|
value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
|
|
1221
1524
|
value = _broadcast(value_shape, value)
|
|
1222
1525
|
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1223
|
-
|
|
1224
1526
|
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
|
|
1225
1527
|
|
|
1226
1528
|
if tuple_index is False:
|
|
@@ -1248,7 +1550,7 @@ def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
|
|
|
1248
1550
|
|
|
1249
1551
|
def tensor_setitem_by_number_with_number(data, index, value):
|
|
1250
1552
|
"""Assigns the tensor by number with number value."""
|
|
1251
|
-
value = F.
|
|
1553
|
+
value = F.cast(value, F.dtype(data))
|
|
1252
1554
|
return tensor_setitem_by_number_with_tensor(data, index, value)
|
|
1253
1555
|
|
|
1254
1556
|
|
|
@@ -1283,7 +1585,7 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
|
|
|
1283
1585
|
data_shape = F.shape(data)
|
|
1284
1586
|
data_dtype = F.dtype(data)
|
|
1285
1587
|
if F.is_sequence_value_unknown(data_shape):
|
|
1286
|
-
value = F.
|
|
1588
|
+
value = F.cast(value, F.dtype(data))
|
|
1287
1589
|
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1288
1590
|
return F.fill(data_dtype, data_shape, value)
|
|
1289
1591
|
|
|
@@ -1315,6 +1617,7 @@ def tensor_setitem_by_ellipsis_with_sequence(data, value):
|
|
|
1315
1617
|
def tensor_setitem_by_bool(data, index, value):
|
|
1316
1618
|
"""Assigns a value to the tensor by boolean."""
|
|
1317
1619
|
data_shape = F.shape(data)
|
|
1620
|
+
data_dtype = F.dtype(data)
|
|
1318
1621
|
if not index:
|
|
1319
1622
|
data_shape = (0,) + data_shape
|
|
1320
1623
|
if isinstance(value, (list, tuple)):
|
|
@@ -1326,6 +1629,7 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1326
1629
|
|
|
1327
1630
|
if F.is_sequence_value_unknown(data_shape) and index:
|
|
1328
1631
|
data_shape = F.dyn_shape(data)
|
|
1632
|
+
value = value.astype(data_dtype)
|
|
1329
1633
|
data = ops.broadcast_to(value, data_shape)
|
|
1330
1634
|
return data
|
|
1331
1635
|
value_shape = F.shape(value)
|
|
@@ -1333,7 +1637,7 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1333
1637
|
if index:
|
|
1334
1638
|
value = F.reshape(value, source_shape)
|
|
1335
1639
|
value = _broadcast(data_shape, value)
|
|
1336
|
-
data = value
|
|
1640
|
+
data = F.cast(value, data_dtype)
|
|
1337
1641
|
return data
|
|
1338
1642
|
|
|
1339
1643
|
|
|
@@ -1417,8 +1721,8 @@ def remove_expanded_dims(tuple_index, data_shape, value):
|
|
|
1417
1721
|
elif const_utils.is_slice(index_out):
|
|
1418
1722
|
indices_out += (index_out,)
|
|
1419
1723
|
not_expanded_dim += (True,)
|
|
1420
|
-
|
|
1421
|
-
|
|
1724
|
+
has_false = has_false or parse_check_slice_index(
|
|
1725
|
+
index_out, data_shape[cur_dim])
|
|
1422
1726
|
cur_dim += 1
|
|
1423
1727
|
elif isinstance(index_out, (Tensor, bool)): # advanced index
|
|
1424
1728
|
if idx_advanced == -1:
|
|
@@ -1490,7 +1794,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
|
|
|
1490
1794
|
ndim = F.rank(a)
|
|
1491
1795
|
if dtype is None:
|
|
1492
1796
|
dtype = F.dtype(a)
|
|
1493
|
-
axes =
|
|
1797
|
+
axes = validator.check_axis_valid(axis, ndim)
|
|
1494
1798
|
if initial is not None:
|
|
1495
1799
|
if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or
|
|
1496
1800
|
not isinstance(initial, (int, float, bool, Tensor))):
|
|
@@ -1505,18 +1809,20 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
|
|
|
1505
1809
|
initial = F.fill(dtype, shape, initial)
|
|
1506
1810
|
a = cmp_fn(a, initial)
|
|
1507
1811
|
|
|
1508
|
-
if isinstance(where, Tensor):
|
|
1812
|
+
if where is not None and not isinstance(where, Tensor):
|
|
1813
|
+
where = Tensor(where, dtype=mstype.bool_)
|
|
1814
|
+
|
|
1815
|
+
if where is not None and (where.shape or not where):
|
|
1509
1816
|
if initial is None:
|
|
1510
1817
|
const_utils.raise_value_error('initial value must be provided for where masks')
|
|
1511
1818
|
ndim_orig = F.rank(a)
|
|
1512
1819
|
# broadcasts input tensors
|
|
1513
1820
|
shape_out = const_utils.infer_out_shape(F.shape(where), F.shape(a), F.shape(initial))
|
|
1514
|
-
broadcast_to = P.BroadcastTo(shape_out)
|
|
1515
1821
|
where = where.astype(mstype.float32)
|
|
1516
|
-
where = broadcast_to(where)
|
|
1822
|
+
where = F.broadcast_to(where, shape_out)
|
|
1517
1823
|
where = where.astype(mstype.bool_)
|
|
1518
|
-
a = broadcast_to(a)
|
|
1519
|
-
initial = broadcast_to(initial)
|
|
1824
|
+
a = F.broadcast_to(a, shape_out)
|
|
1825
|
+
initial = F.broadcast_to(initial, shape_out)
|
|
1520
1826
|
a = F.select(where, a, initial)
|
|
1521
1827
|
axes = const_utils.real_axes(ndim_orig, F.rank(a), axes)
|
|
1522
1828
|
|