mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +26 -32
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +72 -95
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +173 -258
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +240 -145
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +13 -2
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +143 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +11 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +0 -14
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +316 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +21 -28
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +310 -207
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +82 -41
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +13 -18
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +22 -17
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +78 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +4 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +167 -189
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -8
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +470 -251
- mindspore/ops/function/random_func.py +86 -56
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +235 -19
- mindspore/ops/operations/__init__.py +25 -17
- mindspore/ops/operations/_grad_ops.py +52 -7
- mindspore/ops/operations/_inner_ops.py +213 -12
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +64 -280
- mindspore/ops/operations/comm_ops.py +105 -57
- mindspore/ops/operations/custom_ops.py +10 -3
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/math_ops.py +185 -138
- mindspore/ops/operations/nn_ops.py +716 -492
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +14 -12
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +6 -10
- mindspore/parallel/shard.py +4 -4
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
- mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
- mindspore/profiler/parser/ascend_op_generator.py +5 -5
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
- mindspore/profiler/parser/base_timeline_generator.py +9 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +37 -21
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +2 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +139 -71
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +525 -577
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +2 -2
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +14 -7
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +83 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +185 -45
- mindspore/train/serialization.py +390 -150
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +14 -10
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -24,7 +24,8 @@ from mindspore.ops import operations as P
|
|
|
24
24
|
from mindspore.ops.composite import base
|
|
25
25
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
26
26
|
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
|
|
27
|
-
TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo
|
|
27
|
+
TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
|
|
28
|
+
SelectView, CopyWithSlice
|
|
28
29
|
from mindspore.common import dtype as mstype
|
|
29
30
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
30
31
|
from mindspore.common.initializer import Zero
|
|
@@ -33,6 +34,7 @@ from mindspore.common import mutable
|
|
|
33
34
|
from mindspore import ops
|
|
34
35
|
from mindspore.ops.primitive import _primexpr
|
|
35
36
|
from mindspore import _checkparam as validator
|
|
37
|
+
from mindspore.common._stub_tensor import _convert_stub
|
|
36
38
|
|
|
37
39
|
slice_get_item = SliceGetItem()
|
|
38
40
|
hyper_map = base.HyperMap()
|
|
@@ -43,6 +45,8 @@ is_parameter = IsParameter()
|
|
|
43
45
|
getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
|
|
44
46
|
setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
|
|
45
47
|
|
|
48
|
+
selevt_view = SelectView()
|
|
49
|
+
copy_with_slice = CopyWithSlice()
|
|
46
50
|
|
|
47
51
|
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
|
|
48
52
|
new_axis_mask=0, shrink_axis_mask=0):
|
|
@@ -66,19 +70,23 @@ class ValueTransferType(IntEnum):
|
|
|
66
70
|
kGatherND = 9
|
|
67
71
|
kScatterNdUpdate = 10
|
|
68
72
|
kReshape = 11
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
73
|
+
kSelectView = 12
|
|
74
|
+
kUnsqueeze = 13
|
|
75
|
+
kCopyView = 14
|
|
76
|
+
kScatterND = 15
|
|
77
|
+
kNumberToTensor = 16
|
|
78
|
+
kHandleSequenceValue = 17
|
|
79
|
+
kByPass = 18
|
|
80
|
+
kReSetItemByIndex = 19
|
|
81
|
+
kCopySlice = 20
|
|
82
|
+
kSetItemByBool = 21
|
|
83
|
+
kEmptyTensor = 22
|
|
84
|
+
kSetItemByEllipsis = 23
|
|
85
|
+
kFormatIndexTensor = 24
|
|
86
|
+
kGetitemByBoolTensor = 25
|
|
87
|
+
kSetitemByBoolTensor = 26
|
|
88
|
+
kJustReturn = 27
|
|
89
|
+
kRaiseIndexError = 28
|
|
82
90
|
|
|
83
91
|
|
|
84
92
|
def data_update(transfer_types, args, data, new_index, value=None):
|
|
@@ -86,11 +94,14 @@ def data_update(transfer_types, args, data, new_index, value=None):
|
|
|
86
94
|
We finally generate a new tensor when handling tensor getitem/setitem
|
|
87
95
|
by transfer data and value with index.
|
|
88
96
|
"""
|
|
97
|
+
origin_data = data
|
|
89
98
|
for transfer_type, arg in zip(transfer_types, args):
|
|
90
99
|
if transfer_type == ValueTransferType.kUnknown:
|
|
91
100
|
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
92
101
|
if transfer_type <= ValueTransferType.kScatterND:
|
|
93
|
-
data = data_update_by_ops(transfer_type, arg, data, new_index, value)
|
|
102
|
+
data = data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value)
|
|
103
|
+
if transfer_type == ValueTransferType.kJustReturn:
|
|
104
|
+
return _convert_stub(arg)
|
|
94
105
|
if transfer_type == ValueTransferType.kSetItemByBool:
|
|
95
106
|
return tensor_setitem_by_bool(data, new_index, value)
|
|
96
107
|
if transfer_type == ValueTransferType.kCopySlice:
|
|
@@ -114,7 +125,7 @@ def data_update(transfer_types, args, data, new_index, value=None):
|
|
|
114
125
|
return data
|
|
115
126
|
|
|
116
127
|
|
|
117
|
-
def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
|
|
128
|
+
def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=None):
|
|
118
129
|
"""
|
|
119
130
|
Generate a new tensor when handling tensor getitem/setitem
|
|
120
131
|
by ops.
|
|
@@ -135,14 +146,22 @@ def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
|
|
|
135
146
|
F.scatter_nd_update(data, new_index, value)
|
|
136
147
|
elif transfer_type == ValueTransferType.kSelect:
|
|
137
148
|
data = F.select(Tensor(new_index), value, data)
|
|
149
|
+
elif transfer_type == ValueTransferType.kSelectView:
|
|
150
|
+
data = selevt_view(data, arg[0], arg[1])
|
|
151
|
+
elif transfer_type == ValueTransferType.kCopyView:
|
|
152
|
+
value = _broadcast(F.shape(data), F.cast(value, F.dtype(data)))
|
|
153
|
+
data = copy_with_slice(data, value)
|
|
154
|
+
return origin_data
|
|
138
155
|
elif transfer_type == ValueTransferType.kReshape:
|
|
139
156
|
data = F.reshape(data, arg)
|
|
140
157
|
elif transfer_type == ValueTransferType.kGather:
|
|
141
158
|
data = F.gather(data, new_index, 0)
|
|
142
159
|
elif transfer_type == ValueTransferType.kExpandDims:
|
|
143
160
|
data = F.expand_dims(data, 0)
|
|
161
|
+
elif transfer_type == ValueTransferType.kUnsqueeze:
|
|
162
|
+
data = F.unsqueeze(data, arg)
|
|
144
163
|
elif transfer_type == ValueTransferType.kStrideSlice:
|
|
145
|
-
data =
|
|
164
|
+
data = strided_slice(data, arg[0], arg[1], arg[2])
|
|
146
165
|
else:
|
|
147
166
|
raise IndexError(f"Inlvaid transfer type {transfer_type}.")
|
|
148
167
|
return data
|
|
@@ -154,7 +173,7 @@ def value_update(transfer_types, args, data, value):
|
|
|
154
173
|
if transfer_type == ValueTransferType.kByPass:
|
|
155
174
|
continue
|
|
156
175
|
if transfer_type == ValueTransferType.kNumberToTensor:
|
|
157
|
-
value = F.
|
|
176
|
+
value = F.cast(value, F.dtype(data))
|
|
158
177
|
elif transfer_type == ValueTransferType.kHandleSequenceValue:
|
|
159
178
|
op_type, index = arg
|
|
160
179
|
if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
|
|
@@ -192,7 +211,10 @@ def _tensor_setitem(self, index, value):
|
|
|
192
211
|
data_update_types = setitem_info[3]
|
|
193
212
|
data_update_args = setitem_info[4]
|
|
194
213
|
value = value_update(v_transfer_types, v_transfer_args, self, value)
|
|
195
|
-
|
|
214
|
+
output = data_update(data_update_types, data_update_args, self, new_index, value)
|
|
215
|
+
if new_index == "view":
|
|
216
|
+
return (self,)
|
|
217
|
+
return output
|
|
196
218
|
|
|
197
219
|
|
|
198
220
|
tensor_operator_registry.register("__getitem__", _tensor_getitem)
|
|
@@ -286,7 +308,7 @@ def _scalar_to_tensor(input_x):
|
|
|
286
308
|
@_primexpr
|
|
287
309
|
def _check_scalar_tensor_args(args):
|
|
288
310
|
"""For the item, check that the index of the scalar tensor is set."""
|
|
289
|
-
if args
|
|
311
|
+
if args not in ((None,), ()):
|
|
290
312
|
const_utils.raise_value_error("For item, the index of scalar Tensor should not be set.")
|
|
291
313
|
|
|
292
314
|
|
|
@@ -295,15 +317,15 @@ def tensor_item(data, *args):
|
|
|
295
317
|
# transform a.item(tuple(int)) -> a.item(int1,int2...intN)
|
|
296
318
|
if data.ndim == 0:
|
|
297
319
|
_check_scalar_tensor_args(args)
|
|
298
|
-
return data
|
|
320
|
+
return data.asnumpy().item()
|
|
299
321
|
if len(args) == 1 and isinstance(args[0], tuple):
|
|
300
322
|
args = args[0]
|
|
301
323
|
|
|
302
324
|
args_types = hyper_map(F.typeof, args)
|
|
303
325
|
if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
|
|
304
326
|
if data.shape == (1,):
|
|
305
|
-
return data
|
|
306
|
-
const_utils.raise_value_error("Can only convert an array of size 1 to a
|
|
327
|
+
return data.asnumpy().item()
|
|
328
|
+
const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
|
|
307
329
|
|
|
308
330
|
if not const_utils.judge_indexes_types(args_types, mstype.int64):
|
|
309
331
|
const_utils.raise_type_error("The index object cannot be interpreted as an integer")
|
|
@@ -362,7 +384,8 @@ def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
|
|
|
362
384
|
exp_msg = const_utils.gen_exception_msg(
|
|
363
385
|
"Tuple index len({}) is not same to tensor dimension({})", len(tuple_index), data.ndim)
|
|
364
386
|
const_utils.raise_index_error(exp_msg)
|
|
365
|
-
|
|
387
|
+
nubmer_value = F.cast(nubmer_value, F.dtype(data))
|
|
388
|
+
return tensor_itemset_by_tuple_with_tensor(data, tuple_index, nubmer_value)
|
|
366
389
|
|
|
367
390
|
|
|
368
391
|
def _broadcast(broadcast_shape, x):
|
|
@@ -530,10 +553,6 @@ class _TensorIndexGetitem(base.TensorIndexGetitem_):
|
|
|
530
553
|
Type is the same as the element type of data.
|
|
531
554
|
"""
|
|
532
555
|
|
|
533
|
-
def __init__(self, name):
|
|
534
|
-
"""Initialize _TensorIndexGetitem."""
|
|
535
|
-
base.TensorIndexGetitem_.__init__(self, name)
|
|
536
|
-
|
|
537
556
|
def __call__(self, *args):
|
|
538
557
|
pass
|
|
539
558
|
|
|
@@ -580,9 +599,12 @@ def _tensor_index_by_bool(data, bool_value):
|
|
|
580
599
|
"""Tensor getitem by a single bool value"""
|
|
581
600
|
min_data_dim, max_data_dim = 0, 7
|
|
582
601
|
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
602
|
+
output = data
|
|
583
603
|
if bool_value:
|
|
584
|
-
|
|
585
|
-
|
|
604
|
+
output = F.expand_dims(data, 0)
|
|
605
|
+
elif not F.is_sequence_value_unknown(F.shape(data)):
|
|
606
|
+
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
|
|
607
|
+
return output
|
|
586
608
|
|
|
587
609
|
|
|
588
610
|
def get_stride_info_from_integer(tensor_int):
|
|
@@ -599,15 +621,14 @@ def get_stride_info_from_integer(tensor_int):
|
|
|
599
621
|
def _tensor_index_by_integer(data, int_index):
|
|
600
622
|
"""Tensor getitem by a single integer number"""
|
|
601
623
|
data_shape = F.shape(data)
|
|
602
|
-
if not data_shape:
|
|
603
|
-
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
604
|
-
if data.ndim < 1 or data.ndim > 8:
|
|
605
|
-
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
606
|
-
|
|
607
624
|
if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
|
|
608
625
|
tensor_index = _scalar_to_tensor(int_index)
|
|
609
626
|
begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
|
|
610
627
|
else:
|
|
628
|
+
if not data_shape:
|
|
629
|
+
const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
|
|
630
|
+
if data.ndim < 1 or data.ndim > 8:
|
|
631
|
+
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
|
611
632
|
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
|
612
633
|
begin_strides, end_strides, step_strides = \
|
|
613
634
|
const_utils.get_stride_info_from_integer(data_shape, transformed_number)
|
|
@@ -619,7 +640,6 @@ def _tensor_index_by_integer(data, int_index):
|
|
|
619
640
|
end_mask += 2 ** i
|
|
620
641
|
return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
|
621
642
|
|
|
622
|
-
|
|
623
643
|
def _check_dim_shape_valid(data, tensor_index):
|
|
624
644
|
"""check dim and shape of tensor_index for tensor(bool) indexing"""
|
|
625
645
|
if data.ndim < tensor_index.ndim:
|
|
@@ -632,7 +652,8 @@ def _check_dim_shape_valid(data, tensor_index):
|
|
|
632
652
|
|
|
633
653
|
def tensor_index_by_bool_tensor(data, tensor_index):
|
|
634
654
|
"""Tensor getitem by a bool tensor"""
|
|
635
|
-
|
|
655
|
+
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
656
|
+
_check_dim_shape_valid(data, tensor_index)
|
|
636
657
|
tensor_index = tensor_index.nonzero()
|
|
637
658
|
return F.gather_nd(data, tensor_index)
|
|
638
659
|
|
|
@@ -640,7 +661,8 @@ def tensor_index_by_bool_tensor(data, tensor_index):
|
|
|
640
661
|
def tensor_index_by_tensor(data, tensor_index):
|
|
641
662
|
"""Tensor getitem by a single tensor"""
|
|
642
663
|
min_data_dim, max_data_dim = 0, 7
|
|
643
|
-
|
|
664
|
+
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
665
|
+
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
644
666
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
|
|
645
667
|
return F.gather(data, tensor_index, 0)
|
|
646
668
|
if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
|
|
@@ -658,16 +680,22 @@ def tensor_index_by_list(data, list_index):
|
|
|
658
680
|
|
|
659
681
|
data_shape = F.shape(data)
|
|
660
682
|
indexes_types = hyper_map(toptypeof, list_index)
|
|
661
|
-
if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int))
|
|
683
|
+
if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)) \
|
|
684
|
+
and not F.is_sequence_value_unknown(list_index):
|
|
662
685
|
if not F.isconstant(data_shape[0]):
|
|
663
686
|
if all(isinstance(i, bool) for i in list_index):
|
|
664
|
-
|
|
665
|
-
|
|
687
|
+
if F.dyn_shape(data)[0] != len(list_index):
|
|
688
|
+
raise IndexError(
|
|
689
|
+
f'dimension is {F.dyn_shape(data)[0]} but corresponding boolean dimension is {len(list_index)}')
|
|
690
|
+
tensor_index = Tensor(list_index).nonzero()
|
|
691
|
+
return F.gather_nd(data, tensor_index)
|
|
666
692
|
tensor_index = const_utils.sequence_to_index(list_index, None)
|
|
667
693
|
else:
|
|
668
|
-
tensor_index = const_utils.sequence_to_index(
|
|
694
|
+
tensor_index = const_utils.sequence_to_index(
|
|
695
|
+
list_index, data_shape[0])
|
|
669
696
|
if tensor_index is False:
|
|
670
|
-
const_utils.raise_index_error(
|
|
697
|
+
const_utils.raise_index_error(
|
|
698
|
+
"When tensor is indexed by list, the list can't be empty.")
|
|
671
699
|
return F.gather(data, tensor_index, 0)
|
|
672
700
|
|
|
673
701
|
tuple_index_new = ()
|
|
@@ -693,6 +721,29 @@ def judge_tuple_index_dim_check_error(index_dim, data_dim):
|
|
|
693
721
|
f"dim of index:{index_dim}, dim of data:{data_dim}")
|
|
694
722
|
|
|
695
723
|
|
|
724
|
+
class _HandleEmptySlice(base.HandleEmptySlice_):
|
|
725
|
+
"""
|
|
726
|
+
Getting item of Tensor.
|
|
727
|
+
|
|
728
|
+
Args:
|
|
729
|
+
data (Tensor): A tuple to be sliced.
|
|
730
|
+
index: Index of tensor.
|
|
731
|
+
|
|
732
|
+
Returns:
|
|
733
|
+
Type is the same as the element type of data.
|
|
734
|
+
"""
|
|
735
|
+
|
|
736
|
+
def __init__(self, name):
|
|
737
|
+
"""Initialize _HandleEmptySlice."""
|
|
738
|
+
base.HandleEmptySlice_.__init__(self, name)
|
|
739
|
+
|
|
740
|
+
def __call__(self, *args):
|
|
741
|
+
pass
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
_handle_empty_slice = _HandleEmptySlice('handle_zero_tuple_index')
|
|
745
|
+
|
|
746
|
+
|
|
696
747
|
def judge_tuple_index_dim(data, tuple_index):
|
|
697
748
|
"""Judge whether tuple_index's dim is valid"""
|
|
698
749
|
data_dim = data.ndim
|
|
@@ -700,29 +751,55 @@ def judge_tuple_index_dim(data, tuple_index):
|
|
|
700
751
|
for index in tuple_index:
|
|
701
752
|
if isinstance(toptypeof(index), mstype.TensorType) and index.dtype == mstype.bool_:
|
|
702
753
|
index_dim += index.ndim
|
|
703
|
-
|
|
754
|
+
elif not isinstance(toptypeof(index), (mstype.NoneType, mstype.Ellipsis_, mstype.Bool)):
|
|
704
755
|
index_dim += 1
|
|
705
756
|
judge_tuple_index_dim_check_error(index_dim, data_dim)
|
|
706
757
|
|
|
707
758
|
|
|
759
|
+
def judge_simple_tuple_index(data, tuple_index):
|
|
760
|
+
"""Judge whether tuple_index is simple index, which not rollback to cpu ops."""
|
|
761
|
+
op_name = const_utils.TENSOR_GETITEM
|
|
762
|
+
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
763
|
+
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
764
|
+
return F.isconstant(tuple_index) and contain_type == const_utils.ALL_BASIC \
|
|
765
|
+
and F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(F.rank(data))
|
|
766
|
+
|
|
767
|
+
|
|
708
768
|
def tensor_index_by_tuple(data, tuple_index):
|
|
709
769
|
"""Tensor getitem by tuple of various types with None"""
|
|
710
770
|
if not tuple_index:
|
|
711
771
|
return data
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
min_data_dim, max_data_dim = 1, 8
|
|
719
|
-
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
720
|
-
judge_tuple_index_dim(data, tuple_index)
|
|
721
|
-
indexes_types = hyper_map(toptypeof, tuple_index)
|
|
722
|
-
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
723
|
-
if contain_type == const_utils.ALL_BASIC:
|
|
772
|
+
if judge_simple_tuple_index(data, tuple_index):
|
|
773
|
+
tuple_index = convert_tupleslice_to_tensor(tuple_index)
|
|
774
|
+
op_name = const_utils.TENSOR_GETITEM
|
|
775
|
+
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
776
|
+
min_data_dim, max_data_dim = 1, 8
|
|
777
|
+
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
|
724
778
|
return _tensor_getitem_by_tuple_slice(data, tuple_index)
|
|
725
|
-
|
|
779
|
+
|
|
780
|
+
if not F.is_sequence_value_unknown(F.shape(data)):
|
|
781
|
+
judge_tuple_index_dim(data, tuple_index)
|
|
782
|
+
tuple_index, zero_index, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
783
|
+
for non_zero_shape in non_zero_shapes:
|
|
784
|
+
if F.reduce_min(non_zero_shape) == 0:
|
|
785
|
+
tuple_index = zero_index
|
|
786
|
+
break
|
|
787
|
+
if not F.is_sequence_value_unknown(F.shape(data)) and F.isconstant(tuple_index):
|
|
788
|
+
_, stub_zero_dim_tensor = _handle_empty_slice(data, tuple_index)
|
|
789
|
+
if 0 in stub_zero_dim_tensor.shape:
|
|
790
|
+
return F.fill(data.dtype, stub_zero_dim_tensor.shape, 0)
|
|
791
|
+
has_tensor_index = False
|
|
792
|
+
for i in tuple_index:
|
|
793
|
+
if isinstance(i, Tensor):
|
|
794
|
+
has_tensor_index = True
|
|
795
|
+
break
|
|
796
|
+
empty_broadcast_data_shape = False
|
|
797
|
+
_broadcast_data_shape = _handle_scalar_tensor_index(data, tuple_index)
|
|
798
|
+
if has_tensor_index and isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
|
|
799
|
+
empty_broadcast_data_shape = True
|
|
800
|
+
if has_tensor_index and isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
|
|
801
|
+
empty_broadcast_data_shape = True
|
|
802
|
+
return _tensor_index_getitem(data, tuple_index, empty_broadcast_data_shape)
|
|
726
803
|
|
|
727
804
|
|
|
728
805
|
def get_slice_stride(slice_index, dim_size):
|
|
@@ -1039,7 +1116,7 @@ def sequence_to_tensor(value, dtype):
|
|
|
1039
1116
|
|
|
1040
1117
|
if value_elements_type == const_utils.ALL_TENSOR:
|
|
1041
1118
|
value = F.stack(value).astype(dtype)
|
|
1042
|
-
elif value_elements_type == const_utils.NO_TENSOR:
|
|
1119
|
+
elif value_elements_type == const_utils.NO_TENSOR and not F.is_sequence_value_unknown(value):
|
|
1043
1120
|
value = const_utils.make_tensor(value, dtype)
|
|
1044
1121
|
else:
|
|
1045
1122
|
new_value = ()
|
|
@@ -1061,7 +1138,7 @@ def _generate_updates_from_sequence(data, index, value, op_type):
|
|
|
1061
1138
|
def _generate_updates_from_tensor(data, index, value, op_type):
|
|
1062
1139
|
"""Generate an updates tensor from a tensor."""
|
|
1063
1140
|
value = value.astype(data.dtype)
|
|
1064
|
-
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1141
|
+
if F.is_sequence_value_unknown(F.shape(data)) or F.is_sequence_value_unknown(F.shape(index)):
|
|
1065
1142
|
data_shape = F.dyn_shape(data)
|
|
1066
1143
|
index_shape = F.dyn_shape(index)
|
|
1067
1144
|
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
|
@@ -1102,6 +1179,18 @@ def tensor_setitem_by_number(self, index, value):
|
|
|
1102
1179
|
return tensor_setitem_by_number_with_sequence(self, index, value)
|
|
1103
1180
|
|
|
1104
1181
|
|
|
1182
|
+
def _tuple_index_transfer(broadcast_shape, final_shape, new_shape, x, all_empty_tensor):
|
|
1183
|
+
"""Transform tuple index tensor to the required."""
|
|
1184
|
+
if isinstance(broadcast_shape, Tensor):
|
|
1185
|
+
if not all_empty_tensor:
|
|
1186
|
+
x = F.broadcast_to(x, broadcast_shape)
|
|
1187
|
+
x = F.reshape(x, new_shape)
|
|
1188
|
+
x = F.broadcast_to(x, final_shape)
|
|
1189
|
+
return x
|
|
1190
|
+
item = _broadcast(broadcast_shape, x)
|
|
1191
|
+
return _broadcast(final_shape, F.reshape(item, new_shape))
|
|
1192
|
+
|
|
1193
|
+
|
|
1105
1194
|
class _TensorIndexSetitem(base.TensorIndexSetitem_):
|
|
1106
1195
|
"""
|
|
1107
1196
|
Getting item of Tensor.
|
|
@@ -1114,10 +1203,6 @@ class _TensorIndexSetitem(base.TensorIndexSetitem_):
|
|
|
1114
1203
|
Type is the same as the element type of data.
|
|
1115
1204
|
"""
|
|
1116
1205
|
|
|
1117
|
-
def __init__(self, name):
|
|
1118
|
-
"""Initialize _TensorIndexGetitem."""
|
|
1119
|
-
base.TensorIndexSetitem_.__init__(self, name)
|
|
1120
|
-
|
|
1121
1206
|
def __call__(self, *args):
|
|
1122
1207
|
pass
|
|
1123
1208
|
|
|
@@ -1184,13 +1269,12 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|
|
1184
1269
|
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
|
1185
1270
|
|
|
1186
1271
|
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1187
|
-
|
|
1188
|
-
"Not supported to the dynamic shape tensor slice by using tensor of Boolean type")
|
|
1272
|
+
return tensor_setitem_by_tuple_with_tensor(data, (index,), value_tensor.astype(data.dtype))
|
|
1189
1273
|
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
|
1190
1274
|
|
|
1191
1275
|
|
|
1192
1276
|
def tensor_setitem_by_tensor_with_number(data, index, value):
|
|
1193
|
-
value = F.
|
|
1277
|
+
value = F.cast(value, F.dtype(data))
|
|
1194
1278
|
return tensor_setitem_by_tensor_with_tensor(data, index, value)
|
|
1195
1279
|
|
|
1196
1280
|
|
|
@@ -1221,13 +1305,13 @@ def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
|
|
|
1221
1305
|
|
|
1222
1306
|
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
|
1223
1307
|
"""Givens a scalar assign to tensor by slice"""
|
|
1224
|
-
value = F.
|
|
1308
|
+
value = F.cast(value, F.dtype(data))
|
|
1225
1309
|
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
|
1226
1310
|
|
|
1227
1311
|
|
|
1228
1312
|
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|
1229
1313
|
"""Assigns the tensor by tuple with number value."""
|
|
1230
|
-
value = F.
|
|
1314
|
+
value = F.cast(value, F.dtype(data))
|
|
1231
1315
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
|
1232
1316
|
|
|
1233
1317
|
|
|
@@ -1305,7 +1389,123 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
|
|
|
1305
1389
|
return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
|
|
1306
1390
|
|
|
1307
1391
|
|
|
1392
|
+
class _PreSetitemByTuple(base.PreSetitemByTuple_):
|
|
1393
|
+
"""
|
|
1394
|
+
Getting item of Tensor.
|
|
1395
|
+
|
|
1396
|
+
Args:
|
|
1397
|
+
data (Tensor): A tuple to be sliced.
|
|
1398
|
+
index: Index of tensor.
|
|
1399
|
+
|
|
1400
|
+
Returns:
|
|
1401
|
+
Type is the same as the element type of data.
|
|
1402
|
+
"""
|
|
1403
|
+
|
|
1404
|
+
def __init__(self, name):
|
|
1405
|
+
"""Initialize _PreSetitemByTuple."""
|
|
1406
|
+
base.PreSetitemByTuple_.__init__(self, name)
|
|
1407
|
+
|
|
1408
|
+
def __call__(self, *args):
|
|
1409
|
+
pass
|
|
1410
|
+
|
|
1411
|
+
|
|
1412
|
+
_pre_setitem_by_tuple = _PreSetitemByTuple('pre_setitem_by_tuple')
|
|
1413
|
+
|
|
1414
|
+
|
|
1415
|
+
class _HandleBoolTensor(base.HandleBoolTensor_):
|
|
1416
|
+
"""
|
|
1417
|
+
Getting item of Tensor.
|
|
1418
|
+
|
|
1419
|
+
Args:
|
|
1420
|
+
data (Tensor): A tuple to be sliced.
|
|
1421
|
+
index: Index of tensor.
|
|
1422
|
+
|
|
1423
|
+
Returns:
|
|
1424
|
+
Type is the same as the element type of data.
|
|
1425
|
+
"""
|
|
1426
|
+
|
|
1427
|
+
def __init__(self, name):
|
|
1428
|
+
"""Initialize _HandleBoolTensor."""
|
|
1429
|
+
base.HandleBoolTensor_.__init__(self, name)
|
|
1430
|
+
|
|
1431
|
+
def __call__(self, *args):
|
|
1432
|
+
pass
|
|
1433
|
+
|
|
1434
|
+
|
|
1435
|
+
_handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
|
|
1436
|
+
|
|
1437
|
+
|
|
1438
|
+
class _HandleScalarTensorIndex(base.HandleScalarTensorIndex_):
|
|
1439
|
+
"""
|
|
1440
|
+
Getting item of Tensor.
|
|
1441
|
+
|
|
1442
|
+
Args:
|
|
1443
|
+
data (Tensor): A tuple to be sliced.
|
|
1444
|
+
index: Index of tensor.
|
|
1445
|
+
|
|
1446
|
+
Returns:
|
|
1447
|
+
Type is the same as the element type of data.
|
|
1448
|
+
"""
|
|
1449
|
+
|
|
1450
|
+
def __init__(self, name):
|
|
1451
|
+
"""Initialize _HandleBoolTensor."""
|
|
1452
|
+
base.HandleScalarTensorIndex_.__init__(self, name)
|
|
1453
|
+
|
|
1454
|
+
def __call__(self, *args):
|
|
1455
|
+
pass
|
|
1456
|
+
|
|
1457
|
+
|
|
1458
|
+
_handle_scalar_tensor_index = _HandleScalarTensorIndex('handle_scalar_tensor_index')
|
|
1459
|
+
|
|
1460
|
+
|
|
1308
1461
|
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
1462
|
+
"""Assigns the tensor by tuple with tensor value."""
|
|
1463
|
+
if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
|
|
1464
|
+
if F.is_sequence_value_unknown(F.shape(data)):
|
|
1465
|
+
return tensor_copy_slice_from_tuple(data, tuple_index, value)
|
|
1466
|
+
dim1_start, dim1_stop, _ = const_utils.normalize_slice(
|
|
1467
|
+
tuple_index[1], data.shape[1])
|
|
1468
|
+
if dim1_stop - dim1_start <= 0:
|
|
1469
|
+
return data
|
|
1470
|
+
dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
|
|
1471
|
+
start = (dim0_start, dim1_start)
|
|
1472
|
+
stop = (dim0_start + 1, dim1_stop)
|
|
1473
|
+
step = (1, 1)
|
|
1474
|
+
value_shape = (dim1_stop - dim1_start,) + \
|
|
1475
|
+
const_utils.tuple_slice(data.shape, 2, None)
|
|
1476
|
+
value = _broadcast(value_shape, value)
|
|
1477
|
+
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1478
|
+
tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)
|
|
1479
|
+
|
|
1480
|
+
for non_zero_shape in non_zero_shapes:
|
|
1481
|
+
if F.reduce_min(non_zero_shape) == 0:
|
|
1482
|
+
return data
|
|
1483
|
+
value = value.astype(data.dtype)
|
|
1484
|
+
special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
|
|
1485
|
+
= _pre_setitem_by_tuple(data, tuple_index, value)
|
|
1486
|
+
if special_index == 0:
|
|
1487
|
+
return data
|
|
1488
|
+
value = F.reshape(value, new_value_shape)
|
|
1489
|
+
if not tuple_index or special_index == 1:
|
|
1490
|
+
data[True] = value
|
|
1491
|
+
return data
|
|
1492
|
+
|
|
1493
|
+
empty_broadcast_data_shape = False
|
|
1494
|
+
if isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
|
|
1495
|
+
empty_broadcast_data_shape = True
|
|
1496
|
+
if isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
|
|
1497
|
+
empty_broadcast_data_shape = True
|
|
1498
|
+
indices = _tensor_index_setitem(
|
|
1499
|
+
data, tuple_index, value, idx_advanced, empty_broadcast_data_shape)
|
|
1500
|
+
|
|
1501
|
+
updates = _generate_updates_from_tensor(
|
|
1502
|
+
data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
|
1503
|
+
if is_parameter(data):
|
|
1504
|
+
F.scatter_nd_update(data, indices, updates)
|
|
1505
|
+
return data
|
|
1506
|
+
return F.tensor_scatter_update(data, indices, updates)
|
|
1507
|
+
|
|
1508
|
+
def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
|
|
1309
1509
|
"""Assigns the tensor by tuple with tensor value."""
|
|
1310
1510
|
op_name = const_utils.TENSOR_SETITEM
|
|
1311
1511
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
@@ -1323,7 +1523,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|
|
1323
1523
|
value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
|
|
1324
1524
|
value = _broadcast(value_shape, value)
|
|
1325
1525
|
return copy_slice(data, value.astype(data.dtype), start, stop, step)
|
|
1326
|
-
|
|
1327
1526
|
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
|
|
1328
1527
|
|
|
1329
1528
|
if tuple_index is False:
|
|
@@ -1351,7 +1550,7 @@ def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
|
|
|
1351
1550
|
|
|
1352
1551
|
def tensor_setitem_by_number_with_number(data, index, value):
|
|
1353
1552
|
"""Assigns the tensor by number with number value."""
|
|
1354
|
-
value = F.
|
|
1553
|
+
value = F.cast(value, F.dtype(data))
|
|
1355
1554
|
return tensor_setitem_by_number_with_tensor(data, index, value)
|
|
1356
1555
|
|
|
1357
1556
|
|
|
@@ -1386,7 +1585,7 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
|
|
|
1386
1585
|
data_shape = F.shape(data)
|
|
1387
1586
|
data_dtype = F.dtype(data)
|
|
1388
1587
|
if F.is_sequence_value_unknown(data_shape):
|
|
1389
|
-
value = F.
|
|
1588
|
+
value = F.cast(value, F.dtype(data))
|
|
1390
1589
|
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
|
1391
1590
|
return F.fill(data_dtype, data_shape, value)
|
|
1392
1591
|
|
|
@@ -1418,6 +1617,7 @@ def tensor_setitem_by_ellipsis_with_sequence(data, value):
|
|
|
1418
1617
|
def tensor_setitem_by_bool(data, index, value):
|
|
1419
1618
|
"""Assigns a value to the tensor by boolean."""
|
|
1420
1619
|
data_shape = F.shape(data)
|
|
1620
|
+
data_dtype = F.dtype(data)
|
|
1421
1621
|
if not index:
|
|
1422
1622
|
data_shape = (0,) + data_shape
|
|
1423
1623
|
if isinstance(value, (list, tuple)):
|
|
@@ -1429,6 +1629,7 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1429
1629
|
|
|
1430
1630
|
if F.is_sequence_value_unknown(data_shape) and index:
|
|
1431
1631
|
data_shape = F.dyn_shape(data)
|
|
1632
|
+
value = value.astype(data_dtype)
|
|
1432
1633
|
data = ops.broadcast_to(value, data_shape)
|
|
1433
1634
|
return data
|
|
1434
1635
|
value_shape = F.shape(value)
|
|
@@ -1436,7 +1637,7 @@ def tensor_setitem_by_bool(data, index, value):
|
|
|
1436
1637
|
if index:
|
|
1437
1638
|
value = F.reshape(value, source_shape)
|
|
1438
1639
|
value = _broadcast(data_shape, value)
|
|
1439
|
-
data = value
|
|
1640
|
+
data = F.cast(value, data_dtype)
|
|
1440
1641
|
return data
|
|
1441
1642
|
|
|
1442
1643
|
|
|
@@ -398,10 +398,16 @@ def slice2indices(input_slice, shape):
|
|
|
398
398
|
return False
|
|
399
399
|
ndim = len(shape)
|
|
400
400
|
mesh = list()
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
grids
|
|
404
|
-
|
|
401
|
+
range_op = P.Range()
|
|
402
|
+
cast_op = P.Cast()
|
|
403
|
+
grids = [
|
|
404
|
+
range_op(cast_op(start, mstype.int64), cast_op(stop, mstype.int64),
|
|
405
|
+
cast_op(step, mstype.int64))
|
|
406
|
+
]
|
|
407
|
+
grids += [
|
|
408
|
+
range_op(Tensor(0, mstype.int64), cast_op(dim_size, mstype.int64),
|
|
409
|
+
Tensor(1, mstype.int64)) for dim_size in shape[1:]
|
|
410
|
+
]
|
|
405
411
|
for j, grid in enumerate(grids):
|
|
406
412
|
mesh.append(P.Reshape()(grid, tuple(
|
|
407
413
|
[grid.size if j == t else 1 for t in range(ndim)])))
|
|
@@ -543,7 +549,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
|
|
|
543
549
|
updates_shape = indices_shape + data_shape[1:]
|
|
544
550
|
else:
|
|
545
551
|
updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
|
|
546
|
-
return P.
|
|
552
|
+
return P.FillV2()(updates_shape, P.Cast()(value, data_dtype))
|
|
547
553
|
|
|
548
554
|
|
|
549
555
|
def generate_updates_shape(data_shape, index_shape, op_type, is_dynamic):
|
|
@@ -844,7 +850,7 @@ def sequence_to_index(sequence, dim_size):
|
|
|
844
850
|
return False
|
|
845
851
|
if all(isinstance(i, bool) for i in sequence):
|
|
846
852
|
if dim_size is None:
|
|
847
|
-
|
|
853
|
+
return Tensor(sequence)
|
|
848
854
|
seq_size = len(sequence)
|
|
849
855
|
if seq_size != dim_size:
|
|
850
856
|
raise IndexError(f'dimension is {dim_size} but corresponding boolean dimension is {seq_size}')
|
|
@@ -865,11 +871,11 @@ def int_to_index(i, shape):
|
|
|
865
871
|
_check(i, dim_size)
|
|
866
872
|
i = (i + dim_size) % dim_size
|
|
867
873
|
if len(shape) == 1:
|
|
868
|
-
return P.
|
|
874
|
+
return P.FillV2()((1, 1), P.Cast()(i, mstype.int64))
|
|
869
875
|
mesh = list()
|
|
870
876
|
ndim = len(shape) - 1
|
|
871
877
|
for j, size in enumerate(shape[1:]):
|
|
872
|
-
grid = P.Range()(Tensor(0, mstype.int64), P.
|
|
878
|
+
grid = P.Range()(Tensor(0, mstype.int64), P.Cast()(size, mstype.int64), Tensor(1, mstype.int64))
|
|
873
879
|
mesh.append(P.Reshape()(grid, tuple([size if j == t else 1 for t in range(ndim)])))
|
|
874
880
|
shapes = map(P.Shape(), mesh)
|
|
875
881
|
out_shape = infer_out_shape(*shapes)
|
|
@@ -877,7 +883,8 @@ def int_to_index(i, shape):
|
|
|
877
883
|
for arr in mesh:
|
|
878
884
|
mesh_arrays.append(P.BroadcastTo(out_shape)(arr))
|
|
879
885
|
index = P.Stack(-1)(mesh_arrays)
|
|
880
|
-
return P.Concat(-1)((P.
|
|
886
|
+
return P.Concat(-1)((P.FillV2()(P.Shape()(index)[:-1] + (1,),
|
|
887
|
+
P.Cast()(i, mstype.int64)), index))
|
|
881
888
|
|
|
882
889
|
|
|
883
890
|
@constexpr
|