mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -24,7 +24,8 @@ from mindspore.ops import operations as P
|
|
|
24
24
|
from mindspore.ops.composite import base
|
|
25
25
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
26
26
|
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
|
|
27
|
-
TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo
|
|
27
|
+
TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
|
|
28
|
+
SelectView, CopyWithSlice
|
|
28
29
|
from mindspore.common import dtype as mstype
|
|
29
30
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
30
31
|
from mindspore.common.initializer import Zero
|
|
@@ -33,6 +34,7 @@ from mindspore.common import mutable
|
|
|
33
34
|
from mindspore import ops
|
|
34
35
|
from mindspore.ops.primitive import _primexpr
|
|
35
36
|
from mindspore import _checkparam as validator
|
|
37
|
+
from mindspore.common._stub_tensor import _convert_stub
|
|
36
38
|
|
|
37
39
|
slice_get_item = SliceGetItem()
|
|
38
40
|
hyper_map = base.HyperMap()
|
|
@@ -43,6 +45,8 @@ is_parameter = IsParameter()
|
|
|
43
45
|
getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
|
|
44
46
|
setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
|
|
45
47
|
|
|
48
|
+
selevt_view = SelectView()
|
|
49
|
+
copy_with_slice = CopyWithSlice()
|
|
46
50
|
|
|
47
51
|
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
|
|
48
52
|
new_axis_mask=0, shrink_axis_mask=0):
|
|
@@ -66,19 +70,23 @@ class ValueTransferType(IntEnum):
|
|
|
66
70
|
kGatherND = 9
|
|
67
71
|
kScatterNdUpdate = 10
|
|
68
72
|
kReshape = 11
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
73
|
+
kSelectView = 12
|
|
74
|
+
kUnsqueeze = 13
|
|
75
|
+
kCopyView = 14
|
|
76
|
+
kScatterND = 15
|
|
77
|
+
kNumberToTensor = 16
|
|
78
|
+
kHandleSequenceValue = 17
|
|
79
|
+
kByPass = 18
|
|
80
|
+
kReSetItemByIndex = 19
|
|
81
|
+
kCopySlice = 20
|
|
82
|
+
kSetItemByBool = 21
|
|
83
|
+
kEmptyTensor = 22
|
|
84
|
+
kSetItemByEllipsis = 23
|
|
85
|
+
kFormatIndexTensor = 24
|
|
86
|
+
kGetitemByBoolTensor = 25
|
|
87
|
+
kSetitemByBoolTensor = 26
|
|
88
|
+
kJustReturn = 27
|
|
89
|
+
kRaiseIndexError = 28
|
|
82
90
|
|
|
83
91
|
|
|
84
92
|
def data_update(transfer_types, args, data, new_index, value=None):
|
|
@@ -86,11 +94,14 @@ def data_update(transfer_types, args, data, new_index, value=None):
|
|
|
86
94
|
We finally generate a new tensor when handling tensor getitem/setitem
|
|
87
95
|
by transfer data and value with index.
|
|
88
96
|
"""
|
|
97
|
+
origin_data = data
|
|
89
98
|
for transfer_type, arg in zip(transfer_types, args):
|
|
90
99
|
if transfer_type == ValueTransferType.kUnknown:
|
|
91
100
|
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
92
101
|
if transfer_type <= ValueTransferType.kScatterND:
|
|
93
|
-
data = data_update_by_ops(transfer_type, arg, data, new_index, value)
|
|
102
|
+
data = data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value)
|
|
103
|
+
if transfer_type == ValueTransferType.kJustReturn:
|
|
104
|
+
return _convert_stub(arg)
|
|
94
105
|
if transfer_type == ValueTransferType.kSetItemByBool:
|
|
95
106
|
return tensor_setitem_by_bool(data, new_index, value)
|
|
96
107
|
if transfer_type == ValueTransferType.kCopySlice:
|
|
@@ -114,7 +125,7 @@ def data_update(transfer_types, args, data, new_index, value=None):
|
|
|
114
125
|
return data
|
|
115
126
|
|
|
116
127
|
|
|
117
|
-
def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
|
|
128
|
+
def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=None):
|
|
118
129
|
"""
|
|
119
130
|
Generate a new tensor when handling tensor getitem/setitem
|
|
120
131
|
by ops.
|
|
@@ -135,14 +146,22 @@ def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
|
|
|
135
146
|
F.scatter_nd_update(data, new_index, value)
|
|
136
147
|
elif transfer_type == ValueTransferType.kSelect:
|
|
137
148
|
data = F.select(Tensor(new_index), value, data)
|
|
149
|
+
elif transfer_type == ValueTransferType.kSelectView:
|
|
150
|
+
data = selevt_view(data, arg[0], arg[1])
|
|
151
|
+
elif transfer_type == ValueTransferType.kCopyView:
|
|
152
|
+
value = _broadcast(F.shape(data), F.cast(value, F.dtype(data)))
|
|
153
|
+
data = copy_with_slice(data, value)
|
|
154
|
+
return origin_data
|
|
138
155
|
elif transfer_type == ValueTransferType.kReshape:
|
|
139
156
|
data = F.reshape(data, arg)
|
|
140
157
|
elif transfer_type == ValueTransferType.kGather:
|
|
141
158
|
data = F.gather(data, new_index, 0)
|
|
142
159
|
elif transfer_type == ValueTransferType.kExpandDims:
|
|
143
160
|
data = F.expand_dims(data, 0)
|
|
161
|
+
elif transfer_type == ValueTransferType.kUnsqueeze:
|
|
162
|
+
data = F.unsqueeze(data, arg)
|
|
144
163
|
elif transfer_type == ValueTransferType.kStrideSlice:
|
|
145
|
-
data =
|
|
164
|
+
data = strided_slice(data, arg[0], arg[1], arg[2])
|
|
146
165
|
else:
|
|
147
166
|
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
148
167
|
return data
|
|
@@ -154,7 +173,7 @@ def value_update(transfer_types, args, data, value):
|
|
|
154
173
|
if transfer_type == ValueTransferType.kByPass:
|
|
155
174
|
continue
|
|
156
175
|
if transfer_type == ValueTransferType.kNumberToTensor:
|
|
157
|
-
value = F.
|
|
176
|
+
value = F.cast(value, F.dtype(data))
|
|
158
177
|
elif transfer_type == ValueTransferType.kHandleSequenceValue:
|
|
159
178
|
op_type, index = arg
|
|
160
179
|
if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
|
|
@@ -192,7 +211,10 @@ def _tensor_setitem(self, index, value):
|
|
|
192
211
|
data_update_types = setitem_info[3]
|
|
193
212
|
data_update_args = setitem_info[4]
|
|
194
213
|
value = value_update(v_transfer_types, v_transfer_args, self, value)
|
|
195
|
-
|
|
214
|
+
output = data_update(data_update_types, data_update_args, self, new_index, value)
|
|
215
|
+
if new_index == "view":
|
|
216
|
+
return (self,)
|
|
217
|
+
return output
|
|
196
218
|
|
|
197
219
|
|
|
198
220
|
tensor_operator_registry.register("__getitem__", _tensor_getitem)
|
|
@@ -286,7 +308,7 @@ def _scalar_to_tensor(input_x):
|
|
|
286
308
|
@_primexpr
|
|
287
309
|
def _check_scalar_tensor_args(args):
|
|
288
310
|
"""For the item, check that the index of the scalar tensor is set."""
|
|
289
|
-
if args
|
|
311
|
+
if args not in ((None,), ()):
|
|
290
312
|
const_utils.raise_value_error("For item, the index of scalar Tensor should not be set.")
|
|
291
313
|
|
|
292
314
|
|
|
@@ -295,15 +317,15 @@ def tensor_item(data, *args):
|
|
|
295
317
|
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
296
318
|
if data.ndim == 0:
|
|
297
319
|
_check_scalar_tensor_args(args)
|
|
298
|
-
return data
|
|
320
|
+
return data.asnumpy().item()
|
|
299
321
|
if len(args) == 1 and isinstance(args[0], tuple):
|
|
300
322
|
args = args[0]
|
|
301
323
|
|
|
302
324
|
args_types = hyper_map(F.typeof, args)
|
|
303
325
|
if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
|
|
304
326
|
if data.shape == (1,):
|
|
305
|
-
return data
|
|
306
|
-
const_utils.raise_value_error("Can only convert an array of size 1 to a
|
|
327
|
+
return data.asnumpy().item()
|
|
328
|
+
const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
|
|
307
329
|
|
|
308
330
|
if not const_utils.judge_indexes_types(args_types, mstype.int64):
|
|
309
331
|
const_utils.raise_type_error("The index object cannot be interpreted as an integer")
|
|
@@ -362,7 +384,8 @@ def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
|
|
|
362
384
|
exp_msg = const_utils.gen_exception_msg(
|
|
363
385
|
"Tuple index len({}) is not same to tensor dimension({})", len(tuple_index), data.ndim)
|
|
364
386
|
const_utils.raise_index_error(exp_msg)
|
|
365
|
-
|
|
387
|
+
nubmer_value = F.cast(nubmer_value, F.dtype(data))
|
|
388
|
+
return tensor_itemset_by_tuple_with_tensor(data, tuple_index, nubmer_value)
|
|
366
389
|
|
|
367
390
|
|
|
368
391
|
def _broadcast(broadcast_shape, x):
|
|
@@ -530,10 +553,6 @@ class _TensorIndexGetitem(base.TensorIndexGetitem_):
|
|
|
530
553
|
Type is the same as the element type of data.
|
|
531
554
|
"""
|
|
532
555
|
|
|
533
|
-
def __init__(self, name):
|
|
534
|
-
"""Initialize _TensorIndexGetitem."""
|
|
535
|
-
base.TensorIndexGetitem_.__init__(self, name)
|
|
536
|
-
|
|
537
556
|
def __call__(self, *args):
|
|
538
557
|
pass
|
|
539
558
|
|
|
@@ -580,9 +599,12 @@ def _tensor_index_by_bool(data, bool_value):
|
|
|
580
599
|
"""Tensor getitem by a single bool value"""
|
|
581
600
|
min_data_dim, max_data_dim = 0, 7
|
|
582
601
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
602
|
+
output = data
|
|
583
603
|
if bool_value:
|
|
584
|
-
|
|
585
|
-
|
|
604
|
+
output = F.expand_dims(data, 0)
|
|
605
|
+
elif not F.is_sequence_value_unknown(F.shape(data)):
|
|
606
|
+
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
|
|
607
|
+
return output
|
|
586
608
|
|
|
587
609
|
|
|
588
610
|
def get_stride_info_from_integer(tensor_int):
|
|
@@ -599,15 +621,14 @@ def get_stride_info_from_integer(tensor_int):
|
|
|
599
621
|
def _tensor_index_by_integer(data, int_index):
|
|
600
622
|
"""Tensor getitem by a single integer number"""
|
|
601
623
|
data_shape = F.shape(data)
|
|
602
|
-
if not data_shape:
|
|
603
|
-
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
604
|
-
if data.ndim < 1 or data.ndim > 8:
|
|
605
|
-
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
606
|
-
|
|
607
624
|
if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
|
|
608
625
|
tensor_index = _scalar_to_tensor(int_index)
|
|
609
626
|
begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
|
|
610
627
|
else:
|
|
628
|
+
if not data_shape:
|
|
629
|
+
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
630
|
+
if data.ndim < 1 or data.ndim > 8:
|
|
631
|
+
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
611
632
|
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
|
612
633
|
begin_strides, end_strides, step_strides = \
|
|
613
634
|
const_utils.get_stride_info_from_integer(data_shape, transformed_number)
|
|
@@ -619,7 +640,6 @@ def _tensor_index_by_integer(data, int_index):
|
|
|
619
640
|
end_mask += 2 ** i
|
|
620
641
|
return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
621
642
|
|
|
622
|
-
|
|
623
643
|
def _check_dim_shape_valid(data, tensor_index):
|
|
624
644
|
"""check dim and shape of tensor_index for tensor(bool) indexing"""
|
|
625
645
|
if data.ndim < tensor_index.ndim:
|
|
@@ -632,7 +652,8 @@ def _check_dim_shape_valid(data, tensor_index):
|
|
|
632
652
|
|
|
633
653
|
def tensor_index_by_bool_tensor(data, tensor_index):
|
|
634
654
|
"""Tensor getitem by a bool tensor"""
|
|
635
|
-
|
|
655
|
+
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
656
|
+
_check_dim_shape_valid(data, tensor_index)
|
|
636
657
|
tensor_index = tensor_index.nonzero()
|
|
637
658
|
return F.gather_nd(data, tensor_index)
|
|
638
659
|
|
|
@@ -640,7 +661,8 @@ def tensor_index_by_bool_tensor(data, tensor_index):
|
|
|
640
661
|
def tensor_index_by_tensor(data, tensor_index):
|
|
641
662
|
"""Tensor getitem by a single tensor"""
|
|
642
663
|
min_data_dim, max_data_dim = 0, 7
|
|
643
|
-
|
|
664
|
+
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
665
|
+
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
644
666
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
|
|
645
667
|
return F.gather(data, tensor_index, 0)
|
|
646
668
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
|
|
@@ -658,16 +680,22 @@ def tensor_index_by_list(data, list_index):
|
|
|
658
680
|
|
|
659
681
|
data_shape = F.shape(data)
|
|
660
682
|
indexes_types = hyper_map(toptypeof, list_index)
|
|
661
|
-
if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int))
|
|
683
|
+
if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)) \
|
|
684
|
+
and not F.is_sequence_value_unknown(list_index):
|
|
662
685
|
if not F.isconstant(data_shape[0]):
|
|
663
686
|
if all(isinstance(i, bool) for i in list_index):
|
|
664
|
-
|
|
665
|
-
|
|
687
|
+
if F.dyn_shape(data)[0] != len(list_index):
|
|
688
|
+
raise IndexError(
|
|
689
|
+
f'dimension is {F.dyn_shape(data)[0]} but corresponding boolean dimension is {len(list_index)}')
|
|
690
|
+
tensor_index = Tensor(list_index).nonzero()
|
|
691
|
+
return F.gather_nd(data, tensor_index)
|
|
666
692
|
tensor_index = const_utils.sequence_to_index(list_index, None)
|
|
667
693
|
else:
|
|
668
|
-
tensor_index = const_utils.sequence_to_index(
|
|
694
|
+
tensor_index = const_utils.sequence_to_index(
|
|
695
|
+
list_index, data_shape[0])
|
|
669
696
|
if tensor_index is False:
|
|
670
|
-
const_utils.raise_index_error(
|
|
697
|
+
const_utils.raise_index_error(
|
|
698
|
+
"When tensor is indexed by list, the list can't be empty.")
|
|
671
699
|
return F.gather(data, tensor_index, 0)
|
|
672
700
|
|
|
673
701
|
tuple_index_new = ()
|
|
@@ -693,6 +721,29 @@ def judge_tuple_index_dim_check_error(index_dim, data_dim):
|
|
|
693
721
|
f"dim of index:{index_dim}, dim of data:{data_dim}")
|
|
694
722
|
|
|
695
723
|
|
|
724
|
+
class _HandleEmptySlice(base.HandleEmptySlice_):
|
|
725
|
+
"""
|
|
726
|
+
Getting item of Tensor.
|
|
727
|
+
|
|
728
|
+
Args:
|
|
729
|
+
data (Tensor): A tuple to be sliced.
|
|
730
|
+
index: Index of tensor.
|
|
731
|
+
|
|
732
|
+
Returns:
|
|
733
|
+
Type is the same as the element type of data.
|
|
734
|
+
"""
|
|
735
|
+
|
|
736
|
+
def __init__(self, name):
|
|
737
|
+
"""Initialize _HandleEmptySlice."""
|
|
738
|
+
base.HandleEmptySlice_.__init__(self, name)
|
|
739
|
+
|
|
740
|
+
def __call__(self, *args):
|
|
741
|
+
pass
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
_handle_empty_slice = _HandleEmptySlice('handle_zero_tuple_index')
|
|
745
|
+
|
|
746
|
+
|
|
696
747
|
def judge_tuple_index_dim(data, tuple_index):
|
|
697
748
|
"""Judge whether tuple_index's dim is valid"""
|
|
698
749
|
data_dim = data.ndim
|
|
@@ -700,29 +751,55 @@ def judge_tuple_index_dim(data, tuple_index):
|
|
|
700
751
|
for index in tuple_index:
|
|
701
752
|
if isinstance(toptypeof(index), mstype.TensorType) and index.dtype == mstype.bool_:
|
|
702
753
|
index_dim += index.ndim
|
|
703
|
-
|
|
754
|
+
elif not isinstance(toptypeof(index), (mstype.NoneType, mstype.Ellipsis_, mstype.Bool)):
|
|
704
755
|
index_dim += 1
|
|
705
756
|
judge_tuple_index_dim_check_error(index_dim, data_dim)
|
|
706
757
|
|
|
707
758
|
|
|
759
|
+
def judge_simple_tuple_index(data, tuple_index):
|
|
760
|
+
"""Judge whether tuple_index is simple index, which not rollback to cpu ops."""
|
|
761
|
+
op_name = const_utils.TENSOR_GETITEM
|
|
762
|
+
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
763
|
+
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
764
|
+
return F.isconstant(tuple_index) and contain_type == const_utils.ALL_BASIC \
|
|
765
|
+
and F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(F.rank(data))
|
|
766
|
+
|
|
767
|
+
|
|
708
768
|
def tensor_index_by_tuple(data, tuple_index):
|
|
709
769
|
"""Tensor getitem by tuple of various types with None"""
|
|
710
770
|
if not tuple_index:
|
|
711
771
|
return data
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
min_data_dim, max_data_dim = 1, 8
|
|
719
|
-
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
720
|
-
judge_tuple_index_dim(data, tuple_index)
|
|
721
|
-
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
722
|
-
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
723
|
-
if contain_type == const_utils.ALL_BASIC:
|
|
772
|
+
if judge_simple_tuple_index(data, tuple_index):
|
|
773
|
+
tuple_index = convert_tupleslice_to_tensor(tuple_index)
|
|
774
|
+
op_name = const_utils.TENSOR_GETITEM
|
|
775
|
+
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
776
|
+
min_data_dim, max_data_dim = 1, 8
|
|
777
|
+
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
724
778
|
return _tensor_getitem_by_tuple_slice(data, tuple_index)
|
|
725
|
-
|
|
779
|
+
|
|
780
|
+
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
781
|
+
judge_tuple_index_dim(data, tuple_index)
|
|
782
|
+
tuple_index, zero_index, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
783
|
+
for non_zero_shape in non_zero_shapes:
|
|
784
|
+
if F.reduce_min(non_zero_shape) == 0:
|
|
785
|
+
tuple_index = zero_index
|
|
786
|
+
break
|
|
787
|
+
if not F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(tuple_index):
|
|
788
|
+
_, stub_zero_dim_tensor = _handle_empty_slice(data, tuple_index)
|
|
789
|
+
if 0 in stub_zero_dim_tensor.shape:
|
|
790
|
+
return F.fill(data.dtype, stub_zero_dim_tensor.shape, 0)
|
|
791
|
+
has_tensor_index = False
|
|
792
|
+
for i in tuple_index:
|
|
793
|
+
if isinstance(i, Tensor):
|
|
794
|
+
has_tensor_index = True
|
|
795
|
+
break
|
|
796
|
+
empty_broadcast_data_shape = False
|
|
797
|
+
_broadcast_data_shape = _handle_scalar_tensor_index(data, tuple_index)
|
|
798
|
+
if has_tensor_index and isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
|
|
799
|
+
empty_broadcast_data_shape = True
|
|
800
|
+
if has_tensor_index and isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
|
|
801
|
+
empty_broadcast_data_shape = True
|
|
802
|
+
return _tensor_index_getitem(data, tuple_index, empty_broadcast_data_shape)
|
|
726
803
|
|
|
727
804
|
|
|
728
805
|
def get_slice_stride(slice_index, dim_size):
|
|
@@ -1039,7 +1116,7 @@ def sequence_to_tensor(value, dtype):
|
|
|
1039
1116
|
|
|
1040
1117
|
if value_elements_type == const_utils.ALL_TENSOR:
|
|
1041
1118
|
value = F.stack(value).astype(dtype)
|
|
1042
|
-
elif value_elements_type == const_utils.NO_TENSOR:
|
|
1119
|
+
elif value_elements_type == const_utils.NO_TENSOR and not F.is_sequence_value_unknown(value):
|
|
1043
1120
|
value = const_utils.make_tensor(value, dtype)
|
|
1044
1121
|
else:
|
|
1045
1122
|
new_value = ()
|
|
@@ -1061,7 +1138,7 @@ def _generate_updates_from_sequence(data, index, value, op_type):
|
|
|
1061
1138
|
def _generate_updates_from_tensor(data, index, value, op_type):
|
|
1062
1139
|
"""Generate an updates tensor from a tensor."""
|
|
1063
1140
|
value = value.astype(data.dtype)
|
|
1064
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1141
|
+
if F.is_sequence_value_unknown(F.shape(data)) or F.is_sequence_value_unknown(F.shape(index)):
|
|
1065
1142
|
data_shape = F.dyn_shape(data)
|
|
1066
1143
|
index_shape = F.dyn_shape(index)
|
|
1067
1144
|
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
|
@@ -1102,6 +1179,18 @@ def tensor_setitem_by_number(self, index, value):
|
|
|
1102
1179
|
return tensor_setitem_by_number_with_sequence(self, index, value)
|
|
1103
1180
|
|
|
1104
1181
|
|
|
1182
|
+
def _tuple_index_transfer(broadcast_shape, final_shape, new_shape, x, all_empty_tensor):
|
|
1183
|
+
"""Transform tuple index tensor to the required."""
|
|
1184
|
+
if isinstance(broadcast_shape, Tensor):
|
|
1185
|
+
if not all_empty_tensor:
|
|
1186
|
+
x = F.broadcast_to(x, broadcast_shape)
|
|
1187
|
+
x = F.reshape(x, new_shape)
|
|
1188
|
+
x = F.broadcast_to(x, final_shape)
|
|
1189
|
+
return x
|
|
1190
|
+
item = _broadcast(broadcast_shape, x)
|
|
1191
|
+
return _broadcast(final_shape, F.reshape(item, new_shape))
|
|
1192
|
+
|
|
1193
|
+
|
|
1105
1194
|
class _TensorIndexSetitem(base.TensorIndexSetitem_):
|
|
1106
1195
|
"""
|
|
1107
1196
|
Getting item of Tensor.
|
|
@@ -1114,10 +1203,6 @@ class _TensorIndexSetitem(base.TensorIndexSetitem_):
|
|
|
1114
1203
|
Type is the same as the element type of data.
|
|
1115
1204
|
"""
|
|
1116
1205
|
|
|
1117
|
-
def __init__(self, name):
|
|
1118
|
-
"""Initialize _TensorIndexGetitem."""
|
|
1119
|
-
base.TensorIndexSetitem_.__init__(self, name)
|
|
1120
|
-
|
|
1121
1206
|
def __call__(self, *args):
|
|
1122
1207
|
pass
|
|
1123
1208
|
|
|
@@ -1170,7 +1255,8 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
|
|
1170
1255
|
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
|
1171
1256
|
index = F.broadcast_to(index, data.shape)
|
|
1172
1257
|
value = F.cast(value, F.dtype(data))
|
|
1173
|
-
|
|
1258
|
+
while value.ndim < data.ndim:
|
|
1259
|
+
value = value.unsqueeze(-1)
|
|
1174
1260
|
value = F.broadcast_to(value, data.shape)
|
|
1175
1261
|
result = F.select(index, value, data)
|
|
1176
1262
|
return result
|
|
@@ -1184,13 +1270,12 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|
|
1184
1270
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
|
1185
1271
|
|
|
1186
1272
|
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1187
|
-
|
|
1188
|
-
"Not supported to the dynamic shape tensor slice by using tensor of Boolean type")
|
|
1273
|
+
return tensor_setitem_by_tuple_with_tensor(data, (index,), value_tensor.astype(data.dtype))
|
|
1189
1274
|
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
|
1190
1275
|
|
|
1191
1276
|
|
|
1192
1277
|
def tensor_setitem_by_tensor_with_number(data, index, value):
|
|
1193
|
-
value = F.
|
|
1278
|
+
value = F.cast(value, F.dtype(data))
|
|
1194
1279
|
return tensor_setitem_by_tensor_with_tensor(data, index, value)
|
|
1195
1280
|
|
|
1196
1281
|
|
|
@@ -1221,13 +1306,13 @@ def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
|
|
|
1221
1306
|
|
|
1222
1307
|
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
|
1223
1308
|
"""Givens a scalar assign to tensor by slice"""
|
|
1224
|
-
value = F.
|
|
1309
|
+
value = F.cast(value, F.dtype(data))
|
|
1225
1310
|
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
|
1226
1311
|
|
|
1227
1312
|
|
|
1228
1313
|
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|
1229
1314
|
"""Assigns the tensor by tuple with number value."""
|
|
1230
|
-
value = F.
|
|
1315
|
+
value = F.cast(value, F.dtype(data))
|
|
1231
1316
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
|
1232
1317
|
|
|
1233
1318
|
|
|
@@ -1305,7 +1390,123 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
|
1305
1390
|
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
1306
1391
|
|
|
1307
1392
|
|
|
1393
|
+
class _PreSetitemByTuple(base.PreSetitemByTuple_):
|
|
1394
|
+
"""
|
|
1395
|
+
Getting item of Tensor.
|
|
1396
|
+
|
|
1397
|
+
Args:
|
|
1398
|
+
data (Tensor): A tuple to be sliced.
|
|
1399
|
+
index: Index of tensor.
|
|
1400
|
+
|
|
1401
|
+
Returns:
|
|
1402
|
+
Type is the same as the element type of data.
|
|
1403
|
+
"""
|
|
1404
|
+
|
|
1405
|
+
def __init__(self, name):
|
|
1406
|
+
"""Initialize _PreSetitemByTuple."""
|
|
1407
|
+
base.PreSetitemByTuple_.__init__(self, name)
|
|
1408
|
+
|
|
1409
|
+
def __call__(self, *args):
|
|
1410
|
+
pass
|
|
1411
|
+
|
|
1412
|
+
|
|
1413
|
+
_pre_setitem_by_tuple = _PreSetitemByTuple('pre_setitem_by_tuple')
|
|
1414
|
+
|
|
1415
|
+
|
|
1416
|
+
class _HandleBoolTensor(base.HandleBoolTensor_):
|
|
1417
|
+
"""
|
|
1418
|
+
Getting item of Tensor.
|
|
1419
|
+
|
|
1420
|
+
Args:
|
|
1421
|
+
data (Tensor): A tuple to be sliced.
|
|
1422
|
+
index: Index of tensor.
|
|
1423
|
+
|
|
1424
|
+
Returns:
|
|
1425
|
+
Type is the same as the element type of data.
|
|
1426
|
+
"""
|
|
1427
|
+
|
|
1428
|
+
def __init__(self, name):
|
|
1429
|
+
"""Initialize _HandleBoolTensor."""
|
|
1430
|
+
base.HandleBoolTensor_.__init__(self, name)
|
|
1431
|
+
|
|
1432
|
+
def __call__(self, *args):
|
|
1433
|
+
pass
|
|
1434
|
+
|
|
1435
|
+
|
|
1436
|
+
_handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
|
|
1437
|
+
|
|
1438
|
+
|
|
1439
|
+
class _HandleScalarTensorIndex(base.HandleScalarTensorIndex_):
|
|
1440
|
+
"""
|
|
1441
|
+
Getting item of Tensor.
|
|
1442
|
+
|
|
1443
|
+
Args:
|
|
1444
|
+
data (Tensor): A tuple to be sliced.
|
|
1445
|
+
index: Index of tensor.
|
|
1446
|
+
|
|
1447
|
+
Returns:
|
|
1448
|
+
Type is the same as the element type of data.
|
|
1449
|
+
"""
|
|
1450
|
+
|
|
1451
|
+
def __init__(self, name):
|
|
1452
|
+
"""Initialize _HandleBoolTensor."""
|
|
1453
|
+
base.HandleScalarTensorIndex_.__init__(self, name)
|
|
1454
|
+
|
|
1455
|
+
def __call__(self, *args):
|
|
1456
|
+
pass
|
|
1457
|
+
|
|
1458
|
+
|
|
1459
|
+
_handle_scalar_tensor_index = _HandleScalarTensorIndex('handle_scalar_tensor_index')
|
|
1460
|
+
|
|
1461
|
+
|
|
1308
1462
|
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
1463
|
+
"""Assigns the tensor by tuple with tensor value."""
|
|
1464
|
+
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1465
|
+
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1466
|
+
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1467
|
+
dim1_start, dim1_stop, _ = const_utils.normalize_slice(
|
|
1468
|
+
tuple_index[1], data.shape[1])
|
|
1469
|
+
if dim1_stop - dim1_start <= 0:
|
|
1470
|
+
return data
|
|
1471
|
+
dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
|
|
1472
|
+
start = (dim0_start, dim1_start)
|
|
1473
|
+
stop = (dim0_start + 1, dim1_stop)
|
|
1474
|
+
step = (1, 1)
|
|
1475
|
+
value_shape = (dim1_stop - dim1_start,) + \
|
|
1476
|
+
const_utils.tuple_slice(data.shape, 2, None)
|
|
1477
|
+
value = _broadcast(value_shape, value)
|
|
1478
|
+
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1479
|
+
tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
1480
|
+
|
|
1481
|
+
for non_zero_shape in non_zero_shapes:
|
|
1482
|
+
if F.reduce_min(non_zero_shape) == 0:
|
|
1483
|
+
return data
|
|
1484
|
+
value = value.astype(data.dtype)
|
|
1485
|
+
special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
|
|
1486
|
+
= _pre_setitem_by_tuple(data, tuple_index, value)
|
|
1487
|
+
if special_index == 0:
|
|
1488
|
+
return data
|
|
1489
|
+
value = F.reshape(value, new_value_shape)
|
|
1490
|
+
if not tuple_index or special_index == 1:
|
|
1491
|
+
data[True] = value
|
|
1492
|
+
return data
|
|
1493
|
+
|
|
1494
|
+
empty_broadcast_data_shape = False
|
|
1495
|
+
if isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
|
|
1496
|
+
empty_broadcast_data_shape = True
|
|
1497
|
+
if isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
|
|
1498
|
+
empty_broadcast_data_shape = True
|
|
1499
|
+
indices = _tensor_index_setitem(
|
|
1500
|
+
data, tuple_index, value, idx_advanced, empty_broadcast_data_shape)
|
|
1501
|
+
|
|
1502
|
+
updates = _generate_updates_from_tensor(
|
|
1503
|
+
data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
|
1504
|
+
if is_parameter(data):
|
|
1505
|
+
F.scatter_nd_update(data, indices, updates)
|
|
1506
|
+
return data
|
|
1507
|
+
return F.tensor_scatter_update(data, indices, updates)
|
|
1508
|
+
|
|
1509
|
+
def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
|
|
1309
1510
|
"""Assigns the tensor by tuple with tensor value."""
|
|
1310
1511
|
op_name = const_utils.TENSOR_SETITEM
|
|
1311
1512
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
@@ -1323,7 +1524,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1323
1524
|
value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
|
|
1324
1525
|
value = _broadcast(value_shape, value)
|
|
1325
1526
|
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1326
|
-
|
|
1327
1527
|
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
|
|
1328
1528
|
|
|
1329
1529
|
if tuple_index is False:
|
|
@@ -1351,7 +1551,7 @@ def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
|
|
|
1351
1551
|
|
|
1352
1552
|
def tensor_setitem_by_number_with_number(data, index, value):
|
|
1353
1553
|
"""Assigns the tensor by number with number value."""
|
|
1354
|
-
value = F.
|
|
1554
|
+
value = F.cast(value, F.dtype(data))
|
|
1355
1555
|
return tensor_setitem_by_number_with_tensor(data, index, value)
|
|
1356
1556
|
|
|
1357
1557
|
|
|
@@ -1386,7 +1586,7 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
|
|
|
1386
1586
|
data_shape = F.shape(data)
|
|
1387
1587
|
data_dtype = F.dtype(data)
|
|
1388
1588
|
if F.is_sequence_value_unknown(data_shape):
|
|
1389
|
-
value = F.
|
|
1589
|
+
value = F.cast(value, F.dtype(data))
|
|
1390
1590
|
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1391
1591
|
return F.fill(data_dtype, data_shape, value)
|
|
1392
1592
|
|
|
@@ -1418,6 +1618,7 @@ def tensor_setitem_by_ellipsis_with_sequence(data, value):
|
|
|
1418
1618
|
def tensor_setitem_by_bool(data, index, value):
|
|
1419
1619
|
"""Assigns a value to the tensor by boolean."""
|
|
1420
1620
|
data_shape = F.shape(data)
|
|
1621
|
+
data_dtype = F.dtype(data)
|
|
1421
1622
|
if not index:
|
|
1422
1623
|
data_shape = (0,) + data_shape
|
|
1423
1624
|
if isinstance(value, (list, tuple)):
|
|
@@ -1429,6 +1630,7 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1429
1630
|
|
|
1430
1631
|
if F.is_sequence_value_unknown(data_shape) and index:
|
|
1431
1632
|
data_shape = F.dyn_shape(data)
|
|
1633
|
+
value = value.astype(data_dtype)
|
|
1432
1634
|
data = ops.broadcast_to(value, data_shape)
|
|
1433
1635
|
return data
|
|
1434
1636
|
value_shape = F.shape(value)
|
|
@@ -1436,7 +1638,7 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1436
1638
|
if index:
|
|
1437
1639
|
value = F.reshape(value, source_shape)
|
|
1438
1640
|
value = _broadcast(data_shape, value)
|
|
1439
|
-
data = value
|
|
1641
|
+
data = F.cast(value, data_dtype)
|
|
1440
1642
|
return data
|
|
1441
1643
|
|
|
1442
1644
|
|