mindspore 2.4.0__cp311-cp311-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
- mindspore/_c_expression.cpython-311-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1459 @@
|
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""constexpr util"""
|
|
17
|
+
from __future__ import absolute_import
|
|
18
|
+
from enum import IntEnum
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
22
|
+
from mindspore.ops import functional as F
|
|
23
|
+
from mindspore.ops import operations as P
|
|
24
|
+
from mindspore.ops.composite import base
|
|
25
|
+
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
26
|
+
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
|
|
27
|
+
TopTypeof, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
|
|
28
|
+
SelectView, CopyWithSlice
|
|
29
|
+
from mindspore.ops.operations._sequence_ops import TensorToTuple, TensorToScalar, TupleToTensor
|
|
30
|
+
from mindspore.common import dtype as mstype
|
|
31
|
+
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
32
|
+
from mindspore.common.initializer import Zero
|
|
33
|
+
from mindspore.common import Tensor, CSRTensor, COOTensor, mutable
|
|
34
|
+
from mindspore import ops
|
|
35
|
+
from mindspore.ops.primitive import _primexpr
|
|
36
|
+
from mindspore import _checkparam as validator
|
|
37
|
+
from mindspore.common._stub_tensor import _convert_stub
|
|
38
|
+
|
|
39
|
+
# Module-level primitive instances shared by the getitem/setitem helpers in
# this module; they are created once at import time rather than per call.
slice_get_item = SliceGetItem()
hyper_map = base.HyperMap()
# Stacks along the last axis; used to combine per-dimension int64 index
# tensors into a single index tensor (see handle_multi_dim_index_tensor).
stack = P.Stack(axis=-1)
copy_slice = TensorCopySlices()
toptypeof = TopTypeof()
is_parameter = IsParameter()
# Both index-info primitives receive const_utils.is_ascend(); presumably this
# lets them pick an Ascend-specific strategy -- TODO confirm.
getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())

# NOTE: "selevt_view" is a misspelling of "select_view"; the name is kept
# unchanged because other code may reference it.
selevt_view = SelectView()
copy_with_slice = CopyWithSlice()
|
|
50
|
+
|
|
51
|
+
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
                  new_axis_mask=0, shrink_axis_mask=0):
    """
    Slice ``data`` with a cached ``P.StridedSlice`` primitive.

    The five mask arguments configure the primitive instance obtained through
    ``_get_cache_prim``. ``begin_strides``, ``end_strides`` and ``step_strides``
    are passed as the begin, end and stride inputs of the slice.
    """
    slice_op_cls = _get_cache_prim(P.StridedSlice)
    slice_op = slice_op_cls(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
    sliced = slice_op(data, begin_strides, end_strides, step_strides)
    return sliced
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ValueTransferType(IntEnum):
    """Transfer op types of handling tensor getitem/setitem.

    The index-info primitives return sequences of these codes. ``data_update``
    and ``data_update_by_ops`` interpret them on the data side, and
    ``value_update`` interprets them on the value side.

    The numeric order matters: ``data_update`` uses a ``<=`` comparison to
    forward every code from ``kTensorScatterUpdate`` up to and including
    ``kScatterND`` to ``data_update_by_ops``.
    """
    # Invalid / unset code; data_update raises IndexError for it.
    kUnknown = 0
    # --- Codes 1-15: op-based steps (data_update_by_ops; some also used by value_update) ---
    kTensorScatterUpdate = 1
    kExpandDims = 2
    kBroadCast = 3
    kCast = 4
    kSelect = 5
    kGather = 6
    kStrideSlice = 7
    kStrideSliceWithMask = 8
    kGatherND = 9
    kScatterNdUpdate = 10
    kReshape = 11
    kSelectView = 12
    kUnsqueeze = 13
    kCopyView = 14
    kScatterND = 15
    # --- Codes 16-18: value-side steps handled by value_update ---
    kNumberToTensor = 16
    kHandleSequenceValue = 17
    kByPass = 18
    # --- Codes 19-28: handled directly inside data_update ---
    kReSetItemByIndex = 19
    kCopySlice = 20
    kSetItemByBool = 21
    kEmptyTensor = 22
    kSetItemByEllipsis = 23
    kFormatIndexTensor = 24
    kGetitemByBoolTensor = 25
    kSetitemByBoolTensor = 26
    kJustReturn = 27
    kRaiseIndexError = 28
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def data_update(transfer_types, args, data, new_index, value=None):
    """
    Apply a sequence of transfer steps to ``data`` for tensor getitem/setitem.

    ``transfer_types`` and ``args`` are consumed pairwise via ``zip``, so any
    surplus entries in the longer sequence are ignored.

    Args:
        transfer_types: sequence of ``ValueTransferType`` codes.
        args: per-step arguments; the meaning of each entry depends on its code.
        data: the tensor being indexed or updated.
        new_index: the index computed by the index-info primitive. A
            ``kFormatIndexTensor`` step rewrites it for the following steps.
        value: the value to write for setitem (``None`` for getitem).

    Returns:
        The resulting tensor. Several codes return immediately instead of
        continuing with the remaining steps.

    Raises:
        IndexError: for a ``kUnknown`` step or a ``kRaiseIndexError`` step.
    """
    # Kept so data_update_by_ops can hand back the original tensor (kCopyView).
    origin_data = data
    for transfer_type, arg in zip(transfer_types, args):
        if transfer_type == ValueTransferType.kUnknown:
            raise IndexError(f"Inlvaid transfer type {transfer_type}.")
        # Codes 1..kScatterND (kUnknown was rejected above) are op-based
        # transforms delegated to data_update_by_ops.
        if transfer_type <= ValueTransferType.kScatterND:
            data = data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value)
        # `arg` already carries the result; it is only passed through _convert_stub.
        if transfer_type == ValueTransferType.kJustReturn:
            return _convert_stub(arg)
        if transfer_type == ValueTransferType.kSetItemByBool:
            return tensor_setitem_by_bool(data, new_index, value)
        # `arg` supplies the three extra TensorCopySlices operands
        # (presumably begin/end/strides -- TODO confirm).
        if transfer_type == ValueTransferType.kCopySlice:
            return copy_slice(data, value.astype(data.dtype), arg[0], arg[1], arg[2])
        if transfer_type == ValueTransferType.kSetItemByEllipsis:
            return tensor_setitem_by_ellipsis(data, new_index, value)
        # Re-dispatch through the tensor's own __setitem__ using the rewritten index.
        if transfer_type == ValueTransferType.kReSetItemByIndex:
            data[new_index] = value
            return data
        if transfer_type == ValueTransferType.kEmptyTensor:
            return handle_empty_tensor(arg, data)
        # Only rewrites the index; subsequent loop steps see the formatted index.
        if transfer_type == ValueTransferType.kFormatIndexTensor:
            new_index = format_index_tensor(new_index, arg)
        # Gather the elements at the True positions of the boolean index.
        if transfer_type == ValueTransferType.kGetitemByBoolTensor:
            return F.gather_nd(data, new_index.nonzero())
        if transfer_type == ValueTransferType.kSetitemByBoolTensor:
            return handle_setitem_by_bool_tensor(data, new_index, value)
        if transfer_type == ValueTransferType.kRaiseIndexError:
            raise IndexError(
                f'index {arg[0]} is out of bounds for dimension with size {arg[1]}')
    return data
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=None):
    """
    Apply one op-based transfer step (codes 1..kScatterND) to ``data``.

    Args:
        transfer_type: the ``ValueTransferType`` code of this step.
        arg: step-specific argument (slice info, target shape, axis, ...).
        data: the current tensor.
        new_index: the index computed for this getitem/setitem.
        origin_data: the tensor originally passed to ``data_update``; it is
            what a ``kCopyView`` step returns.
        value: the setitem value (``None`` for getitem).

    Returns:
        The transformed tensor, or ``origin_data`` for ``kCopyView``.

    Raises:
        IndexError: if ``transfer_type`` is not one of the codes handled here.
    """
    if transfer_type == ValueTransferType.kStrideSliceWithMask:
        # arg = (stride_info, mask_index). stride_info supplies begin/end/strides;
        # mask_index supplies begin_mask, end_mask and shrink_axis_mask
        # (ellipsis_mask and new_axis_mask are fixed at 0).
        stride_info, mask_index = arg[0], arg[1]
        data = strided_slice(data, stride_info[0], stride_info[1], stride_info[2],
                             mask_index[0], mask_index[1], 0, 0, mask_index[2])
    elif transfer_type == ValueTransferType.kGatherND:
        # A list index is first combined into one stacked int64 index tensor.
        if isinstance(new_index, list):
            new_index = handle_multi_dim_index_tensor(new_index, arg)
        # Wrap negative indices using the sizes of the leading dimensions that
        # the index's last axis addresses.
        new_index = format_index_tensor(new_index, (None, F.shape(data)[:F.shape(new_index)[-1]]))
        data = F.gather_nd(data, new_index)
    elif transfer_type == ValueTransferType.kTensorScatterUpdate:
        if isinstance(new_index, list):
            new_index = handle_multi_dim_index_tensor(new_index, arg)
        data = F.tensor_scatter_update(data, new_index, value)
    elif transfer_type == ValueTransferType.kScatterNdUpdate:
        # NOTE(review): the return value is discarded, so this relies on
        # scatter_nd_update updating `data` in place (presumably a Parameter) -- confirm.
        F.scatter_nd_update(data, new_index, value)
    elif transfer_type == ValueTransferType.kSelect:
        # new_index acts as the condition mask: `value` where True, `data` elsewhere.
        data = F.select(Tensor(new_index), value, data)
    elif transfer_type == ValueTransferType.kSelectView:
        data = selevt_view(data, arg[0], arg[1])
    elif transfer_type == ValueTransferType.kCopyView:
        # Cast and broadcast `value` to data's dtype/shape, then copy it into
        # `data` (presumably a view of origin_data -- TODO confirm). The original
        # tensor, not the view, is returned.
        value = _broadcast(F.shape(data), F.cast(value, F.dtype(data)))
        data = copy_with_slice(data, value)
        return origin_data
    elif transfer_type == ValueTransferType.kReshape:
        data = F.reshape(data, arg)
    elif transfer_type == ValueTransferType.kGather:
        # Gather along axis 0.
        data = F.gather(data, new_index, 0)
    elif transfer_type == ValueTransferType.kExpandDims:
        # Insert a new leading axis.
        data = F.expand_dims(data, 0)
    elif transfer_type == ValueTransferType.kUnsqueeze:
        # Insert a new axis at position `arg`.
        data = F.unsqueeze(data, arg)
    elif transfer_type == ValueTransferType.kStrideSlice:
        data = strided_slice(data, arg[0], arg[1], arg[2])
    else:
        raise IndexError(f"Inlvaid transfer type {transfer_type}.")
    return data
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def value_update(transfer_types, args, data, value):
    """
    Prepare ``value`` before it is written into ``data`` by a tensor setitem.

    Each ``ValueTransferType`` code, paired positionally with its argument,
    transforms ``value`` (dtype cast, broadcast, reshape, scatter, ...) so it
    matches what the subsequent data update expects.

    Raises:
        IndexError: if a code is not supported on the value side.
    """
    for step, step_arg in zip(transfer_types, args):
        if step == ValueTransferType.kByPass:
            continue

        if step == ValueTransferType.kNumberToTensor or step == ValueTransferType.kCast:
            # Both codes just align the value's dtype with the target tensor.
            value = F.cast(value, F.dtype(data))
        elif step == ValueTransferType.kHandleSequenceValue:
            op_type, seq_index = step_arg
            if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
                seq_index = Tensor(seq_index)
            value = _generate_updates_from_sequence(data, seq_index, value, op_type)
        elif step == ValueTransferType.kExpandDims:
            value = F.expand_dims(value, step_arg)
        elif step == ValueTransferType.kBroadCast:
            target_dtype = F.dtype(data)
            value = _broadcast(step_arg, value.astype(target_dtype))
        elif step == ValueTransferType.kReshape:
            value = F.reshape(value, step_arg)
        elif step == ValueTransferType.kScatterND:
            scatter_indices, scatter_shape = step_arg[0], step_arg[1]
            value = F.scatter_nd(scatter_indices, value, scatter_shape)
        else:
            raise IndexError(f"Inlvaid transfer type {step}.")
    return value
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _tensor_getitem(self, index):
    """Implement tensor getitem: resolve ``index``, then apply the resulting data transfer steps."""
    index_info = getitem_tensor_index_info(self, index)
    resolved_index, update_types, update_args = index_info
    return data_update(update_types, update_args, self, resolved_index)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _tensor_setitem(self, index, value):
    """
    Implement tensor setitem.

    The index-info primitive yields the resolved index plus two step lists:
    one transforms ``value``, the other applies that value to ``self``.
    When the resolved index is the string ``"view"``, ``(self,)`` is returned
    instead of the computed result.
    """
    info = setitem_tensor_index_info(self, index, value)
    resolved_index = info[0]
    value_steps, value_step_args = info[1], info[2]
    data_steps, data_step_args = info[3], info[4]

    prepared_value = value_update(value_steps, value_step_args, self, value)
    result = data_update(data_steps, data_step_args, self, resolved_index, prepared_value)
    if resolved_index == "view":
        return (self,)
    return result
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# Register the indexing helpers on the shared tensor operator registry
# (presumably consulted by Tensor.__getitem__/__setitem__).
for _op_name, _op_impl in (("__getitem__", _tensor_getitem),
                           ("__setitem__", _tensor_setitem)):
    setattr(tensor_operator_registry, _op_name, _op_impl)
del _op_name, _op_impl
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _tensor_add(self, other):
    """
    Implement tensor ``+``.

    A tuple/list ``other`` is converted to a tensor with ``self``'s dtype.
    A COOTensor ``other`` defers to the COOTensor's own addition.
    """
    rhs = other
    if isinstance(rhs, (tuple, list)):
        rhs = sequence_to_tensor(rhs, F.dtype(self))
    if isinstance(rhs, COOTensor):
        return rhs + self
    return F.add(self, rhs)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _tensor_sub(self, other):
    """
    Implement tensor ``-``.

    A tuple/list ``self`` is converted using ``other``'s dtype, then a
    tuple/list ``other`` is converted using the (possibly converted) ``self``'s
    dtype. Subtracting a COOTensor scatters its values out of ``self`` with
    ``F.tensor_scatter_sub``.
    """
    lhs, rhs = self, other
    if isinstance(lhs, (tuple, list)):
        lhs = sequence_to_tensor(lhs, F.dtype(rhs))
    if isinstance(rhs, (tuple, list)):
        rhs = sequence_to_tensor(rhs, F.dtype(lhs))
    if isinstance(rhs, COOTensor):
        return F.tensor_scatter_sub(lhs, rhs.indices, rhs.values)
    return F.sub(lhs, rhs)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _tensor_mul(self, other):
    """
    Implement tensor ``*``.

    Sparse (CSRTensor/COOTensor) operands defer to their own multiplication.
    A tuple/list ``other`` is converted to a tensor with ``self``'s dtype first.
    """
    if isinstance(other, (CSRTensor, COOTensor)):
        return other * self
    rhs = other
    if isinstance(rhs, (tuple, list)):
        rhs = sequence_to_tensor(rhs, F.dtype(self))
    return F.mul(self, rhs)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def _tensor_matmul(self, other):
    """Implement tensor ``@`` by delegating to ``F.matmul``."""
    product = F.matmul(self, other)
    return product
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _tensor_div(self, other):
    """
    Implement tensor true division via ``F.div``.

    Tuple/list operands are converted to tensors: ``self`` with ``other``'s
    dtype first, then ``other`` with the (possibly converted) ``self``'s dtype.
    """
    numerator, denominator = self, other
    if isinstance(numerator, (tuple, list)):
        numerator = sequence_to_tensor(numerator, F.dtype(denominator))
    if isinstance(denominator, (tuple, list)):
        denominator = sequence_to_tensor(denominator, F.dtype(numerator))
    return F.div(numerator, denominator)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _tensor_mod(self, other):
    """
    Implement tensor ``%`` via ``F.floormod``.

    Tuple/list operands are converted to tensors: ``self`` with ``other``'s
    dtype first, then ``other`` with the (possibly converted) ``self``'s dtype.
    """
    dividend, divisor = self, other
    if isinstance(dividend, (tuple, list)):
        dividend = sequence_to_tensor(dividend, F.dtype(divisor))
    if isinstance(divisor, (tuple, list)):
        divisor = sequence_to_tensor(divisor, F.dtype(dividend))
    return F.floormod(dividend, divisor)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def _tensor_pow(self, other):
    """Implement tensor ``**`` (``self ** other``); a tuple/list exponent is converted with ``self``'s dtype."""
    exponent = other
    if isinstance(exponent, (tuple, list)):
        exponent = sequence_to_tensor(exponent, F.dtype(self))
    return F.tensor_pow(self, exponent)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _tensor_rpow(self, other):
    """Implement reflected ``**`` (``other ** self``); a tuple/list ``other`` is converted with ``self``'s dtype."""
    left = other
    if isinstance(left, (tuple, list)):
        left = sequence_to_tensor(left, F.dtype(self))
    return F.tensor_pow(left, self)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _tensor_floordiv(self, other):
    """
    Implement tensor ``//`` via ``F.floordiv``.

    Tuple/list operands are converted to tensors: ``self`` with ``other``'s
    dtype first, then ``other`` with the (possibly converted) ``self``'s dtype.
    """
    dividend, divisor = self, other
    if isinstance(dividend, (tuple, list)):
        dividend = sequence_to_tensor(dividend, F.dtype(divisor))
    if isinstance(divisor, (tuple, list)):
        divisor = sequence_to_tensor(divisor, F.dtype(dividend))
    return F.floordiv(dividend, divisor)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
# Register the arithmetic helpers above on the shared tensor operator registry.
for _op_name, _op_impl in (('__add__', _tensor_add),
                           ('__sub__', _tensor_sub),
                           ('__mul__', _tensor_mul),
                           ('__matmul__', _tensor_matmul),
                           ('__truediv__', _tensor_div),
                           ('__mod__', _tensor_mod),
                           ('__pow__', _tensor_pow),
                           ('__rpow__', _tensor_rpow),
                           ('__floordiv__', _tensor_floordiv)):
    setattr(tensor_operator_registry, _op_name, _op_impl)
del _op_name, _op_impl
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def _scalar_to_tensor(input_x):
    """
    Convert a scalar to a tensor.

    When ``ops.isconstant`` reports ``input_x`` as constant, ``P.ScalarToTensor``
    is used with the scalar's own dtype. Otherwise the scalar is added to a
    mutable ``Tensor(0)``, so the addition produces a tensor.
    """
    if not ops.isconstant(input_x):
        # Adding a mutable zero tensor (Tensor(0), not Tensor([0])) yields a tensor result.
        return ops.add(input_x, mutable(Tensor(0)))
    return P.ScalarToTensor()(input_x, ops.dtype(input_x))
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@_primexpr
def _check_scalar_tensor_args(args):
    """
    Validate the arguments of ``item`` on a 0-d tensor.

    Only ``()`` or ``(None,)`` are accepted, meaning no index was given.
    Anything else is reported through ``const_utils.raise_value_error``.
    """
    if args in ((None,), ()):
        return
    const_utils.raise_value_error("For item, the index of scalar Tensor should not be set.")
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def tensor_item(data, *args):
    """
    Tensor getitem by index whose dtype is int or tuple with int.

    Implements ``Tensor.item``. Accepted forms:
        * ``item()`` / ``item(None)``: for a 0-d tensor or a tensor of shape ``(1,)``.
        * ``item(i)``: ``i`` indexes the flattened tensor.
        * ``item(i1, ..., iN)`` or ``item((i1, ..., iN))`` with ``N == data.ndim``.

    Most forms convert the selected element with ``TensorToScalar``. The
    one-index-per-dimension form returns ``tensor_index_by_tuple``'s result
    directly.

    Raises:
        ValueError (via const_utils.raise_value_error): an index given for a
            0-d tensor; no index for a tensor whose shape is not ``(1,)``;
            or more than one, but not ``ndim``, indices.
        TypeError (via const_utils.raise_type_error): a non-int64 index.
    """
    if data.ndim == 0:
        # 0-d tensor: no index may be given; convert the single element.
        _check_scalar_tensor_args(args)
        return TensorToScalar()(data)
    # transform a.item(tuple(int)) -> a.item(int1,int2...intN)
    if len(args) == 1 and isinstance(args[0], tuple):
        args = args[0]

    args_types = hyper_map(F.typeof, args)
    # No index (or None): only a tensor of shape (1,) can be converted.
    if not args or const_utils.judge_index_type(args_types[0], mstype.type_none):
        if data.shape == (1,):
            return TensorToScalar()(data[0])
        const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")

    if not const_utils.judge_indexes_types(args_types, mstype.int64):
        const_utils.raise_type_error("The index object cannot be interpreted as an integer")

    # One index per dimension: delegate to tuple indexing (not wrapped in TensorToScalar).
    if len(args) == data.ndim:
        return tensor_index_by_tuple(data, args)
    if len(args) > 1:
        const_utils.raise_value_error("Incorrect number of indices for array")
    # A single integer indexes the flattened tensor.
    output = _tensor_index_by_integer(F.reshape(data, (-1,)), args[0])
    return TensorToScalar()(output)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def tensor_itemset(data, *args):
    """
    Implement ``Tensor.itemset``.

    Accepted forms:
        * ``itemset(value)``: only valid for a tensor of shape ``(1,)``.
        * ``itemset(int_index, value)``: ``int_index`` addresses the flattened tensor.
        * ``itemset(tuple_index, value)``: one index per dimension.
    """
    arg_count = len(args)
    if arg_count == 0:
        const_utils.raise_value_error("'Tensor.itemset()' must have at least one argument, but got None.")
    if arg_count > 2:
        too_many_msg = const_utils.gen_exception_msg(
            "'Tensor.itemset()' must have at most 2 argument, but got {}.", arg_count)
        const_utils.raise_value_error(too_many_msg)
    if arg_count == 2:
        position, new_value = args[0], args[1]
        if const_utils.judge_index_type(F.typeof(position), mstype.int64):
            return tensor_itemset_by_number_with_number(data, position, new_value)
        if isinstance(position, tuple):
            return tensor_itemset_by_tuple_with_number(data, position, new_value)
        const_utils.raise_type_error("The index object cannot be interpreted as an integer")
    return tensor_itemset_with_number(data, args[0])
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
# Register Tensor.item / Tensor.itemset implementations on the shared registry.
for _op_name, _op_impl in (("item", tensor_item),
                           ("itemset", tensor_itemset)):
    setattr(tensor_operator_registry, _op_name, _op_impl)
del _op_name, _op_impl
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def tensor_itemset_with_number(data, number_value):
    """
    Handle ``itemset(value)`` for a tensor whose shape is ``(1,)``.

    Returns a new tensor holding ``number_value`` with ``data``'s dtype. A
    non-number value, or a tensor whose shape is not ``(1,)``, is reported via
    ``const_utils.raise_index_error``.
    """
    is_number = const_utils.judge_index_type(F.typeof(number_value), mstype.number_type)
    if not is_number:
        const_utils.raise_index_error(const_utils.gen_exception_msg(
            "'Tensor.itemset()' only support number input, but got {}", number_value))
    if data.shape != (1,):
        const_utils.raise_index_error(const_utils.gen_exception_msg(
            "Only tensor which shape is (1,) support 1 arg that means omit index, "
            "but the tensor shape is {} and got 1 input.", data.shape))
    return const_utils.make_tensor((number_value,), F.dtype(data))
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def tensor_itemset_by_number_with_number(data, int_index, number_value):
    """
    Set one element of ``data`` addressed by an integer index into the flattened tensor.

    ``data`` is flattened with ``reshape(-1)``, updated, and reshaped back to its
    original shape.
    """
    original_shape = F.shape(data)
    flat = F.reshape(data, (-1,))
    updated_flat = tensor_setitem_by_number_with_number(flat, int_index, number_value)
    return F.reshape(updated_flat, original_shape)
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
    """
    Set one element of ``data`` addressed by a full tuple index (one entry per dimension).

    The value is cast to ``data``'s dtype before delegating to
    ``tensor_itemset_by_tuple_with_tensor``. A tuple whose length differs from
    ``data.ndim`` is reported via ``const_utils.raise_index_error``.
    """
    # NOTE: "nubmer_value" is misspelled but kept for keyword-argument compatibility.
    index_len = len(tuple_index)
    if index_len != data.ndim:
        const_utils.raise_index_error(const_utils.gen_exception_msg(
            "Tuple index len({}) is not same to tensor dimension({})", index_len, data.ndim))
    casted_value = F.cast(nubmer_value, F.dtype(data))
    return tensor_itemset_by_tuple_with_tensor(data, tuple_index, casted_value)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _broadcast(broadcast_shape, x):
    """Return ``x`` broadcast to ``broadcast_shape``, or ``x`` itself if it already has that shape."""
    if F.shape(x) != broadcast_shape:
        return F.broadcast_to(x, broadcast_shape)
    return x
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, item):
    """
    Reshape an indexing tensor into its final layout.

    ``item`` is broadcast to ``broadcast_shape``, reshaped to ``new_shape``,
    and then broadcast to ``final_shape``.
    """
    broadcast_item = _broadcast(broadcast_shape, item)
    reshaped_item = F.reshape(broadcast_item, new_shape)
    return _broadcast(final_shape, reshaped_item)
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def _transform_ellipsis_to_slice(data, tuple_index, op_name):
    """
    Expand every ellipsis in ``tuple_index`` into empty slices.

    Each slice, int, tensor and sequence index consumes one dimension of
    ``data``. Each ellipsis is replaced by as many empty slices as there are
    dimensions left over, and all non-ellipsis entries are kept unchanged.

    Args:
        data: the tensor being indexed.
        tuple_index: the tuple index to normalize.
        op_name: operation name forwarded to ``const_utils.get_pos_of_indexes_types``.

    Returns:
        A new tuple in which every ellipsis has been expanded.

    Raises:
        IndexError (via const_utils.raise_index_error): if the dimension-consuming
            indices outnumber ``data.ndim``.
    """
    data_shape = F.shape(data)
    indexes_types = hyper_map(toptypeof, tuple_index)
    slice_positions, ellipsis_positions, _, int_positions, _, tensor_positions, sequence_positions = \
        const_utils.get_pos_of_indexes_types(indexes_types, op_name)

    # Dimensions not claimed by any explicit index; each ellipsis expands to this many slices.
    ellipsis_occupy_dims = data.ndim - (len(slice_positions) + len(int_positions) +
                                        len(tensor_positions) + len(sequence_positions))

    if ellipsis_occupy_dims < 0:
        # Too many indices for the tensor's rank. This is an error whether or
        # not an ellipsis is present, so no ellipsis-count guard is needed.
        exp_msg = const_utils.gen_exception_msg(
            "Tuple index {} out rang of tensor shape {}.", tuple_index, data_shape)
        const_utils.raise_index_error(exp_msg)

    tuple_index_new = ()
    for i, index in enumerate(tuple_index):
        if i in ellipsis_positions:
            for _ in range(ellipsis_occupy_dims):
                empty_slice = const_utils.make_empty_slice()
                tuple_index_new += (empty_slice,)
        else:
            tuple_index_new += (index,)
    return tuple_index_new
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def handle_empty_tensor(arg, data):
    """
    Build the result of an update that yields an empty tensor.

    ``arg`` is the target shape and the dtype is taken from ``data``. When the
    shape contains a 0, the tensor is created with a ``Zero`` initializer;
    otherwise it comes from ``const_utils.make_tensor([], ...)``.
    """
    if 0 not in arg:
        return const_utils.make_tensor([], data.dtype, arg)
    zero_init = Zero()
    # Presumably allows the Zero initializer to accept zero-sized dimensions -- TODO confirm.
    zero_init.__enable_zero_dim__ = True
    return Tensor(shape=arg, dtype=data.dtype, init=zero_init)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def handle_multi_dim_index_tensor(new_index, arg):
    """
    Combine a list of per-dimension indices into one int64 index tensor.

    Every entry of ``new_index`` is converted to a Tensor, brought to a common
    shape, and cast to int64. The results are then stacked along a new last
    axis with ``stack`` (``P.Stack(axis=-1)``). The output is used as the
    index for ``F.gather_nd`` / ``F.tensor_scatter_update`` in
    ``data_update_by_ops``.

    Args:
        new_index: list of index entries, one per indexed dimension.
        arg: either ``(broadcast_shape,)``, in which case every entry is simply
            broadcast, or ``(broadcast_shape, final_shape, index_tensor_new_shape,
            slice_shapes, tensor_positions, fancy_position)`` when tensor indices
            are mixed with other (presumably slice-derived) entries.

    Returns:
        The stacked int64 index tensor.
    """
    # Counts the non-tensor entries processed so far; selects their slot in compute_slice_shape.
    slice_cnt = 0
    new_indies_tensor = []
    if len(arg) == 1:
        # Simple case: broadcast every entry to the single broadcast_shape.
        broadcast_shape = arg[0]
        new_index = hyper_map(F.partial(Tensor), new_index)
        broadcast_tensors = hyper_map(
            F.partial(_broadcast, broadcast_shape), new_index)
        new_broadcast_tensors = ()
        for tensor in broadcast_tensors:
            new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
        new_index = stack(new_broadcast_tensors)
        return new_index
    broadcast_shape, final_shape, index_tensor_new_shape, slice_shapes, tensor_positions, fancy_position = arg
    for i, index in enumerate(new_index):
        if i in tensor_positions:
            # Tensor index: broadcast, reshape to index_tensor_new_shape, broadcast to final_shape.
            transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
                                                          Tensor(index))
            new_indies_tensor.append(F.cast(transform_tensor, mstype.int64))
        else:
            # Other entry: reshape into the shape computed for its position, then broadcast.
            shape = const_utils.compute_slice_shape(
                slice_shapes, len(broadcast_shape), slice_cnt, fancy_position)
            array = Tensor(index).reshape(shape)
            slice_index_tensor = _broadcast(final_shape, array)
            new_indies_tensor.append(F.cast(slice_index_tensor, mstype.int64))
            slice_cnt += 1
    new_index = stack(new_indies_tensor)
    return new_index
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def format_index_tensor(index, arg):
    """Wrap negative index values by adding the corresponding dimension size.

    Args:
        index: Either a list of index tensors, which is modified in place, or
            data accepted by ``Tensor(...)``.
        arg: Pair ``(format_indices, format_dims)``. For a list ``index``, the
            element ``index[format_indices[k]]`` is wrapped with
            ``format_dims[k]``. Otherwise ``Tensor(format_dims)`` is added
            wherever ``index`` is negative.

    Returns:
        The same list object (list input) or a new Tensor.
    """
    format_indices, format_dims = arg
    if not isinstance(index, list):
        index_t = Tensor(index)
        dims_t = Tensor(format_dims)
        return F.select(index_t < 0, index_t + dims_t, index_t)
    for pos, dim in zip(format_indices, format_dims):
        item = index[pos]
        index[pos] = F.select(item < 0, item + dim, item)
    return index
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def handle_setitem_by_bool_tensor(data, index, value):
    """Write ``value`` into ``data`` wherever the bool tensor ``index`` is True.

    If ``index`` has no True element, ``data`` is returned unchanged.
    Otherwise ``value`` is processed in three steps:

    1. Cast to ``data``'s dtype and broadcast to one row per True element.
    2. Scatter into a ``data``-shaped tensor.
    3. Merge with ``data`` through ``F.select``, using ``index`` padded and
       broadcast to ``data.shape`` as the mask.
    """
    casted = F.cast(value, F.dtype(data))
    true_coords = index.nonzero()
    num_true = true_coords.shape[0]
    if num_true == 0:
        return data
    row_shape = (num_true,) + data.shape[index.ndim:]
    scattered = F.scatter_nd(true_coords, _broadcast(row_shape, casted), data.shape)
    mask_shape = const_utils.generate_padding_shape(index.shape, len(data.shape))
    mask = _broadcast(data.shape, index.reshape(mask_shape))
    return F.select(mask, scattered, data)
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def _expand_data_dims(data, tuple_index):
    """Expand ``data`` for each ``None`` or ``True`` entry of ``tuple_index``.

    Each ``None`` entry becomes an empty slice, and each ``True`` entry becomes
    the int64 tensor ``[0]``. ``data`` gets a new axis (``F.expand_dims``) at
    each of those positions, applied in increasing order. Other entries are
    kept unchanged. A ``False`` entry is reported through
    ``const_utils.raise_index_error``.

    Returns:
        tuple: ``(expanded_data, new_tuple_index)``.
    """
    element_types = hyper_map(toptypeof, tuple_index)
    new_axes = ()
    rewritten_index = ()
    for pos, (item, item_type) in enumerate(zip(tuple_index, element_types)):
        if isinstance(item_type, mstype.NoneType):
            rewritten_index += (const_utils.make_empty_slice(),)
            new_axes += (pos,)
        elif isinstance(item_type, mstype.Bool):
            if not item:
                const_utils.raise_index_error("Bool element of tuple index must be 'True', but got 'False'.")
            rewritten_index += (const_utils.make_tensor([0], mstype.int64),)
            new_axes += (pos,)
        else:
            rewritten_index += (item,)

    # Insert a size-1 axis at every recorded position, in increasing order.
    for axis in new_axes:
        data = F.expand_dims(data, axis)

    return data, rewritten_index
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def _convert_list_index_to_tensor(list_index):
|
|
528
|
+
"""convert list to tensor"""
|
|
529
|
+
has_bool = False
|
|
530
|
+
has_int = False
|
|
531
|
+
has_no_bool_int = False
|
|
532
|
+
for idx in list_index:
|
|
533
|
+
if isinstance(idx, bool):
|
|
534
|
+
has_bool = True
|
|
535
|
+
elif isinstance(idx, int):
|
|
536
|
+
has_int = True
|
|
537
|
+
else:
|
|
538
|
+
has_no_bool_int = True
|
|
539
|
+
|
|
540
|
+
all_bool = has_bool and not has_int and not has_no_bool_int
|
|
541
|
+
all_int = has_int and not has_bool and not has_no_bool_int
|
|
542
|
+
all_bool_or_int = not has_no_bool_int
|
|
543
|
+
|
|
544
|
+
if all_int:
|
|
545
|
+
index_tensor = TupleToTensor()(tuple(list_index), mstype.int64)
|
|
546
|
+
return index_tensor
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
if all_bool:
|
|
550
|
+
index_tensor = TupleToTensor()(tuple(list_index), mstype.bool_)
|
|
551
|
+
return index_tensor
|
|
552
|
+
|
|
553
|
+
# convert bool to int if index is mixture of (bool, int)
|
|
554
|
+
if all_bool_or_int:
|
|
555
|
+
new_index = []
|
|
556
|
+
for idx in list_index:
|
|
557
|
+
if isinstance(idx, bool):
|
|
558
|
+
new_idx = int(idx)
|
|
559
|
+
new_index.append(new_idx)
|
|
560
|
+
else:
|
|
561
|
+
new_index.append(idx)
|
|
562
|
+
index_tensor = TupleToTensor()(tuple(new_index), mstype.int64)
|
|
563
|
+
return index_tensor
|
|
564
|
+
|
|
565
|
+
return None
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
class _TensorIndexGetitem(base.TensorIndexGetitem_):
    """
    Getting item of Tensor.

    Callable used in this module as ``_tensor_index_getitem(data, index)``, with
    ``index`` a slice or a tuple of indices. ``__call__`` below is only a
    placeholder (``pass``). The indexing itself is presumably provided by the
    base class ``base.TensorIndexGetitem_`` — TODO confirm.

    Args:
        data (Tensor): The tensor to be indexed.
        index: Index of tensor (a slice, or a tuple of indices).

    Returns:
        The items of ``data`` selected by ``index``.
    """

    def __call__(self, *args):
        pass


_tensor_index_getitem = _TensorIndexGetitem('tensor_index_getitem')
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def tensor_index_by_slice(data, slice_index):
    """Index ``data`` with a single ``slice``.

    Thin wrapper that forwards ``(data, slice_index)`` to ``_tensor_index_getitem``.
    """
    result = _tensor_index_getitem(data, slice_index)
    return result
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
def tensor_index_by_number(data, number_index):
    """Index ``data`` with a Python number.

    A bool is dispatched to ``_tensor_index_by_bool`` and an int to
    ``_tensor_index_by_integer``. Any other number (e.g. a float) is reported
    through ``const_utils.raise_index_error``.
    """
    if isinstance(number_index, bool):
        # Checked before int because bool is a subclass of int.
        return _tensor_index_by_bool(data, number_index)
    if not isinstance(number_index, int):
        message = const_utils.gen_exception_msg(
            "Number index of tensor must be int or bool, but got {}.", number_index)
        return const_utils.raise_index_error(message)
    return _tensor_index_by_integer(data, number_index)
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def _tensor_index_by_bool(data, bool_value):
    """Index ``data`` with a single Python bool.

    ``True`` prepends a size-1 axis (``F.expand_dims(data, 0)``). ``False`` is
    reported through ``const_utils.raise_index_error`` when ``data``'s shape is
    static. With a dynamic shape, ``data`` is returned unchanged. ``data.ndim``
    is first checked with ``const_utils.judge_data_dim(..., 0, 7)``.
    """
    const_utils.judge_data_dim(data.ndim, 0, 7)
    result = data
    if bool_value:
        result = F.expand_dims(data, 0)
    elif not F.is_sequence_value_unknown(F.shape(data)):
        return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
    return result
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def get_stride_info_from_integer(int_index):
    """Express the integer index ``i`` as strided-slice bounds for ``[i:i+1:1]``.

    Returns:
        tuple: ``((begin,), (end,), (step,))`` with ``end == begin + 1`` and ``step == 1``.
    """
    begin = int_index
    return (begin,), (begin + 1,), (1,)
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
def _tensor_index_by_integer(data, int_index):
    """Index ``data`` with a single integer, dropping axis 0 (like ``x[i]``).

    Implemented with ``strided_slice`` over ``[i:i+1:1]`` on axis 0, with
    ``shrink_axis_mask=1``.
    """
    begin, end, step = get_stride_info_from_integer(int_index)
    # Bits 2..7 set: 0b11111100 == sum(2 ** k for k in range(2, 8)) == 252.
    # The same mask is used for both begin_mask and end_mask.
    wide_mask = 0b11111100
    shrink_first_axis = 1
    return strided_slice(data, begin, end, step, wide_mask, wide_mask, 0, 0, shrink_first_axis)
|
|
633
|
+
|
|
634
|
+
def _check_dim_shape_valid(data, tensor_index):
|
|
635
|
+
"""check dim and shape of tensor_index for tensor(bool) indexing"""
|
|
636
|
+
if data.ndim < tensor_index.ndim:
|
|
637
|
+
raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
|
|
638
|
+
f"dim of index:{tensor_index.ndim}, dim of data:{data.ndim}")
|
|
639
|
+
if data.shape[:tensor_index.ndim] != tensor_index.shape[:]:
|
|
640
|
+
raise IndexError(f"The shape of index {tensor_index.shape} does not match the shape "
|
|
641
|
+
f"of the indexed data {data.shape}")
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
def tensor_index_by_bool_tensor(data, tensor_index):
    """Gather the elements of ``data`` where the bool tensor ``tensor_index`` is True.

    The index's rank/shape is validated only when ``data``'s shape is static.
    """
    shape_is_static = not F.is_sequence_value_unknown(F.shape(data))
    if shape_is_static:
        _check_dim_shape_valid(data, tensor_index)
    true_coords = tensor_index.nonzero()
    return F.gather_nd(data, true_coords)
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def tensor_index_by_tensor(data, tensor_index):
    """Index ``data`` with a single int or bool tensor.

    For an int index, negative values are wrapped by ``F.shape(data)[0]`` and
    rows are gathered along axis 0. A bool index is delegated to
    ``tensor_index_by_bool_tensor``. Any other dtype is reported through
    ``const_utils.raise_index_error``. When the shape is static, ``data.ndim``
    is checked with ``judge_data_dim(..., 0, 7)``.
    """
    if not F.is_sequence_value_unknown(F.shape(data)):
        const_utils.judge_data_dim(data.ndim, 0, 7)
    index_dtype = F.dtype(tensor_index)
    if const_utils.check_type_isinstance(index_dtype, mstype.Int):
        axis0_size = F.shape(data)[0]
        wrapped = F.select(tensor_index < 0, tensor_index + axis0_size, tensor_index)
        return F.gather(data, wrapped, 0)
    if const_utils.check_type_isinstance(index_dtype, mstype.Bool):
        return tensor_index_by_bool_tensor(data, tensor_index)
    message = const_utils.gen_exception_msg(
        "The tensor index must be int or bool type, but got {}.", index_dtype)
    const_utils.raise_index_error(message)
    return data
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
def tensor_index_by_list(data, list_index):
    """Index ``data`` with a Python list of ints and/or bools.

    * If ``data.shape[0]`` is a constant and every element is a bool, the list
      is a mask over axis 0. Its length must equal ``data.shape[0]``
      (otherwise IndexError), and the rows at True positions are gathered.
    * An empty list is reported through ``const_utils.raise_index_error``.
    * A list that ``_convert_list_index_to_tensor`` turns into a tensor is
      forwarded to ``tensor_index_by_tensor``.
    * Anything else is indexed as the equivalent tuple by ``tensor_index_by_tuple``.

    When ``data.ndim`` is constant it is checked with ``judge_data_dim(..., 1, 8)``.
    """
    if F.isconstant(data.ndim):
        const_utils.judge_data_dim(data.ndim, 1, 8)

    data_shape = F.shape(data)
    # NOTE(review): all() is also True for an empty list, so an empty list_index
    # on a static, non-zero dim 0 raises the length-mismatch IndexError below
    # instead of the "can't be empty" error further down — confirm intent.
    if F.isconstant(data_shape[0]) and all(isinstance(i, bool) for i in list_index):
        if data_shape[0] != len(list_index):
            raise IndexError(
                f'dimension is {data_shape[0]} but corresponding boolean dimension is {len(list_index)}')
        true_rows = Tensor(list_index).nonzero()
        return F.gather_nd(data, true_rows)

    if not list_index:
        const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")

    index_tensor = _convert_list_index_to_tensor(list_index)
    if index_tensor is None:
        return tensor_index_by_tuple(data, tuple(list_index))
    return tensor_index_by_tensor(data, index_tensor)
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def judge_tuple_index_dim_check_error(index_dim, data_dim):
    """Raise IndexError when a tuple index consumes more dims than the data has.

    Args:
        index_dim (int): Number of data dims consumed by the tuple index.
        data_dim (int): Rank of the indexed data.

    Raises:
        IndexError: if ``index_dim > data_dim``.
    """
    if index_dim <= data_dim:
        return
    raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
                     f"dim of index:{index_dim}, dim of data:{data_dim}")
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def judge_tuple_index_dim(data, tuple_index):
    """Check that ``tuple_index`` does not consume more dims than ``data`` has.

    Each bool tensor counts as its own ``ndim``. ``None``, Ellipsis and Python
    bools count as 0, and every other entry counts as 1.
    """
    total_dims = 0
    for item in tuple_index:
        item_type = toptypeof(item)
        if isinstance(item_type, mstype.TensorType) and item.dtype == mstype.bool_:
            total_dims += item.ndim
        elif not isinstance(item_type, (mstype.NoneType, mstype.Ellipsis_, mstype.Bool)):
            total_dims += 1
    judge_tuple_index_dim_check_error(total_dims, data.ndim)
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def tensor_index_by_tuple(data, tuple_index):
    """Index ``data`` with a tuple of mixed index types (``None`` allowed).

    * An empty tuple returns ``data`` unchanged.
    * With a static shape, the index rank is validated first.
    * Bool tensors are then processed by ``_handle_bool_tensor``. If any of
      the nonzero shapes it reports contains a 0, the ``zero_index`` it
      returned is used instead.
    """
    if not tuple_index:
        return data

    if not F.is_sequence_value_unknown(F.shape(data)):
        judge_tuple_index_dim(data, tuple_index)
    normalized_index, zero_index, nonzero_shapes = _handle_bool_tensor(tuple_index)
    for shape in nonzero_shapes:
        if 0 in shape:
            normalized_index = zero_index
            break

    return _tensor_index_getitem(data, normalized_index)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
def get_slice_stride(slice_index, dim_size):
    """Return ``(start, stop, step)`` for ``slice_index`` with defaults filled in.

    Missing fields default to ``start=0``, ``stop=dim_size`` and ``step=1``.
    Tensor-valued fields are converted to Python ``int``.
    """
    start = slice_get_item(slice_index, "start")
    stop = slice_get_item(slice_index, "stop")
    step = slice_get_item(slice_index, "step")

    start = 0 if start is None else start
    stop = dim_size if stop is None else stop
    step = 1 if step is None else step

    start = int(start) if isinstance(start, Tensor) else start
    stop = int(stop) if isinstance(stop, Tensor) else stop
    step = int(step) if isinstance(step, Tensor) else step

    return start, stop, step
|
|
753
|
+
|
|
754
|
+
|
|
755
|
+
def cal_tuple_slice_mask(data_shape, tuple_index):
    """Compute the strided_slice ``begin_mask`` and ``end_mask`` for ``tuple_index``.

    Bit ``i`` of ``begin_mask`` (``end_mask``) is set when ``tuple_index[i]``
    is a slice whose start (stop) is None. It is also set for every trailing
    dimension ``i`` in ``range(len(tuple_index), len(data_shape))``.
    """
    begin_mask = 0
    end_mask = 0
    for pos, item in enumerate(tuple_index):
        if isinstance(item, slice):
            bit = 2 ** pos
            if slice_get_item(item, "start") is None:
                begin_mask += bit
            if slice_get_item(item, "stop") is None:
                end_mask += bit
    for pos in range(len(tuple_index), len(data_shape)):
        bit = 2 ** pos
        begin_mask += bit
        end_mask += bit
    return begin_mask, end_mask
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
    """Build a stacked int64 indices tensor from a tuple of int index tensors.

    The steps are:

    1. Check that every entry has an int dtype (``const_utils.check_types_valid``).
    2. Broadcast the entry shapes together, prefixing a rank < 2 result with 1.
    3. Broadcast each entry to that shape, cast it to int64, and ``stack`` the results.
    """
    dtypes = hyper_map(F.dtype, tuple_index)
    const_utils.check_types_valid(dtypes, mstype.int_type, op_name)
    shapes = hyper_map(F.shape, tuple_index)
    common_shape = const_utils.generate_broadcast_shape(shapes, op_name)
    if len(common_shape) < 2:
        common_shape = (1,) + common_shape
    expanded = hyper_map(F.partial(_broadcast, common_shape), tuple_index)
    int64_tensors = ()
    for item in expanded:
        int64_tensors += (F.cast(item, mstype.int64),)
    return stack(int64_tensors)
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
def parse_check_slice_index(index_out, dim_size):
    """Report whether ``index_out`` is statically known to be an empty slice.

    The slice is normalized against ``dim_size``. If start, stop and step are
    all constants, the result of ``const_utils.check_slice_empty`` is
    returned; otherwise ``False``.
    """
    start, stop, step = const_utils.normalize_slice(index_out, dim_size)
    if F.isconstant(start) and F.isconstant(stop) and F.isconstant(step):
        return const_utils.check_slice_empty(start, stop, step)
    return False
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
    """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.

    Pass 1 rewrites every position of ``tuple_index`` (zipped with
    ``F.shape(data)``):

    * ints become int64 scalar tensors after ``check_range``;
    * sequences become index tensors via ``sequence_to_index``;
    * tensors are type-checked and cast to int64;
    * slices become element lists.

    Int and sequence positions are then treated as tensor positions. Pass 2
    turns every tensor/slice entry into an index tensor of ``final_shape``,
    and the results are stacked.

    Returns:
        Tensor: stacked indices, or ``False`` if any slice is statically empty.
    """
    data_shape = F.shape(data)
    tensor_indexes, slice_indexes = [], []
    indexes_types = hyper_map(toptypeof, tuple_index)
    slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
        const_utils.get_pos_of_indexes_types(indexes_types, op_name)
    tuple_index_new, slice_shapes = (), ()
    # Pass 1: normalize each index entry into a tensor or a slice element list.
    for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
        if i in int_positions:
            int_index = const_utils.check_range(index, dim_size)
            tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
            tensor_positions += (i,)
        elif i in sequence_positions:
            tensor_index = const_utils.sequence_to_index(index, dim_size)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
            tensor_positions += (i,)
        elif i in tensor_positions:
            # NOTE(review): the check is against int_type only, while the message says "int or bool".
            invalid = const_utils.check_type_invalid(F.dtype(index), mstype.int_type)
            if invalid:
                exp_msg = const_utils.gen_exception_msg(
                    "The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
                const_utils.raise_index_error(exp_msg)
            tensor_index = F.cast(index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
        elif i in slice_positions:
            # A statically empty slice means there is nothing to index: signal with False.
            if parse_check_slice_index(index, dim_size):
                return False
            slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
            slice_shapes += (len(slice_ele_list_index),)
            tuple_index_new += (slice_ele_list_index,)
            slice_indexes.append(slice_ele_list_index)

    tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
    broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
        const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
                                                                    slice_shapes, op_name, fancy_position)

    # Pass 2: expand every tensor/slice entry to final_shape.
    final_index_tensors = []
    slice_cnt = 0
    for i, index in enumerate(tuple_index_new):
        if i in tensor_positions:
            transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
                                                          index)
            final_index_tensors.append(transform_tensor)
        elif i in slice_positions:
            slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
                                                         slice_shapes, fancy_position)
            final_index_tensors.append(slice_index_tensor)
            slice_cnt += 1

    indices = stack(final_index_tensors)
    return indices
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
def sequence_to_tensor(value, dtype):
    """Convert a list/tuple ``value`` into a Tensor of ``dtype``.

    The conversion depends on the element kinds reported by
    ``const_utils.check_value_elements``:

    * all tensors: stacked, then cast with ``astype``;
    * no tensors: converted with ``TupleToTensor``;
    * a mix: non-tensor elements are wrapped with ``make_tensor``, then the
      elements are stacked and cast.

    Intended for 1-D tensor/non-tensor mixtures.
    """
    element_kind = const_utils.check_value_elements(hyper_map(toptypeof, value))
    if element_kind == const_utils.ALL_TENSOR:
        return F.stack(value).astype(dtype)
    if element_kind == const_utils.NO_TENSOR:
        scalars = tuple(value) if isinstance(value, list) else value
        if dtype == mstype.float16:
            # NOTE(review): float16 is built as float32 then cast — presumably
            # TupleToTensor cannot produce float16 directly; confirm.
            return F.cast(TupleToTensor()(scalars, mstype.float32), dtype)
        return TupleToTensor()(scalars, dtype)
    stacked = ()
    for element in value:
        if not isinstance(element, Tensor):
            element = const_utils.make_tensor(element, dtype)
        stacked += (element,)
    return F.stack(stacked).astype(dtype)
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
def _generate_updates_from_sequence(data, index, value, op_type):
    """Build the updates tensor from a list/tuple ``value``.

    ``value`` is converted with ``sequence_to_tensor`` using ``data``'s dtype.
    For ``SET_ITEM_BY_NON_TENSOR`` it is returned as is. Otherwise it is shaped
    by ``_generate_updates_from_tensor``.
    """
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    if op_type != const_utils.SET_ITEM_BY_NON_TENSOR:
        return _generate_updates_from_tensor(data, index, value_tensor, op_type)
    return value_tensor
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
def _generate_updates_from_tensor(data, index, value, op_type):
    """Cast ``value`` to ``data``'s dtype and broadcast it to the updates shape.

    The target shape comes from
    ``const_utils.generate_updates_shape(data.shape, index.shape, op_type)``.
    """
    casted = value.astype(data.dtype)
    target_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type)
    return ops.broadcast_to(casted, target_shape)
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
# Tensor getitem implementations are above this line, setitem implementations below.
|
|
895
|
+
|
|
896
|
+
def _tensor_index_transfer(index, broadcast_shape, final_shape, new_shape):
    """Broadcast, reshape and re-broadcast one tuple-index tensor to ``final_shape``.

    If ``final_shape`` contains a 0 dimension, a zero-filled tensor of that
    shape (with ``index``'s dtype) is returned instead.
    """
    if 0 in final_shape:
        return F.fill(index.dtype, final_shape, 0)
    # broadcast_to () is not supported on Ascend, so skip it for an empty broadcast shape.
    expanded = index if broadcast_shape == () else F.broadcast_to(index, broadcast_shape)
    return F.broadcast_to(F.reshape(expanded, new_shape), final_shape)
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
def reshape_with_check(x, new_shape):
    """Reshape ``x`` to ``new_shape``, given either as a tuple or as a Tensor.

    A Tensor-valued ``new_shape`` is first converted with ``TensorToTuple``.
    """
    target = TensorToTuple()(new_shape) if isinstance(new_shape, Tensor) else new_shape
    return F.reshape(x, target)
|
|
914
|
+
|
|
915
|
+
|
|
916
|
+
class _TensorIndexSetitem(base.TensorIndexSetitem_):
    """
    Setting item of Tensor.

    This module calls it in two forms:

    * ``_tensor_index_setitem(data, index, value)`` with a slice ``index``,
      whose result is unpacked as ``(indices, value_shape, start, stop, step, value)``.
    * ``_tensor_index_setitem(data, tuple_index, value, idx_advanced,
      empty_broadcast_data_shape)``, whose result is used as scatter indices.

    ``__call__`` below is only a placeholder (``pass``). The computation is
    presumably provided by the base class ``base.TensorIndexSetitem_`` — TODO confirm.

    Args:
        data (Tensor): The tensor whose items are to be set.
        index: Index of tensor.
        value: The value to write.
    """

    def __call__(self, *args):
        pass


_tensor_index_setitem = _TensorIndexSetitem('tensor_index_setitem')
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
def tensor_setitem_by_slice(self, index, value):
    """Set the items of a tensor selected by a slice.

    Returns ``self`` unchanged when the normalized start equals stop.
    Otherwise the value is broadcast to ``value_shape`` and written in one of
    two ways: with ``copy_slice`` when not on Ascend and step is 1, or else
    with ``F.tensor_scatter_update``.
    """
    indices, value_shape, start, stop, step, new_value = _tensor_index_setitem(self, index, value)
    if start == stop:
        return self
    new_value = F.broadcast_to(new_value, value_shape)
    if not const_utils.is_ascend() and step == 1:
        return copy_slice(self, new_value, (start,), (stop,), (step,))
    return F.tensor_scatter_update(self, indices, new_value)
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
def tensor_setitem_by_ellipsis(self, index, value):
    """Assign ``value`` to ``self[...]``, dispatching on the type of ``value``.

    The dispatch is:

    * Python int/float/bool: ``tensor_setitem_by_ellipsis_with_number``;
    * Tensor: ``tensor_setitem_by_ellipsis_with_tensor``;
    * anything else (sequence): ``tensor_setitem_by_ellipsis_with_sequence``.

    ``index`` (the Ellipsis itself) is not used.
    """
    if isinstance(value, Tensor):
        return tensor_setitem_by_ellipsis_with_tensor(self, value)
    if isinstance(value, (int, float, bool)):
        return tensor_setitem_by_ellipsis_with_number(self, value)
    return tensor_setitem_by_ellipsis_with_sequence(self, value)
|
|
956
|
+
|
|
957
|
+
|
|
958
|
+
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
    """Scatter ``value`` into ``data`` at the axis-0 positions given by an int ``index``.

    The steps are:

    1. Give a 0-d ``index`` a trailing axis.
    2. Cast ``value`` to ``data``'s dtype and broadcast it to
       ``index.shape + data.shape[1:]``.
    3. Wrap negative indices by ``data.shape[0]``.

    A Parameter ``data`` is updated with ``F.scatter_nd_update`` and returned;
    other tensors go through ``F.tensor_scatter_update``.
    """
    if F.rank(index) == 0:
        index = F.expand_dims(index, -1)

    shape = F.shape(data)
    target_shape = index.shape + shape[1:]
    updates = ops.broadcast_to(F.cast(value, F.dtype(data)), target_shape)
    axis0_size = shape[0]
    wrapped = F.select(index < 0, index + axis0_size, index)
    scatter_index = F.expand_dims(wrapped, -1)
    if is_parameter(data):
        F.scatter_nd_update(data, scatter_index, updates)
        return data
    return F.tensor_scatter_update(data, scatter_index, updates)
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
    """Replace the elements of ``data`` where the bool ``index`` is True with ``value``.

    Both ``index`` and ``value`` (cast to ``data``'s dtype) are padded with
    ``generate_padding_shape`` and broadcast to ``data.shape`` before ``F.select``.
    """
    rank = len(data.shape)
    mask = index.reshape(const_utils.generate_padding_shape(index.shape, rank))
    mask = F.broadcast_to(mask, data.shape)
    fill = F.cast(value, F.dtype(data))
    fill = fill.reshape(const_utils.generate_padding_shape(fill.shape, rank))
    fill = F.broadcast_to(fill, data.shape)
    return F.select(mask, fill, data)
|
|
985
|
+
|
|
986
|
+
|
|
987
|
+
def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
    """Set items of ``data`` selected by a tensor ``index`` to ``value_tensor``.

    An index whose dtype class is ``const_utils.INT_`` uses the int-index path.
    Any other index uses the bool-mask path.
    """
    index_kind = const_utils.get_index_tensor_dtype(F.dtype(index))
    if index_kind != const_utils.INT_:
        return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
    return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
|
995
|
+
|
|
996
|
+
|
|
997
|
+
def tensor_setitem_by_tensor_with_number(data, index, value):
    """Set items selected by a tensor ``index`` to a number cast to ``data``'s dtype."""
    casted = F.cast(value, F.dtype(data))
    return tensor_setitem_by_tensor_with_tensor(data, index, casted)
|
|
1000
|
+
|
|
1001
|
+
|
|
1002
|
+
def tensor_setitem_by_tensor_with_sequence(data, index, value):
    """Set items selected by a tensor ``index`` to a list/tuple value.

    The sequence is first converted with ``sequence_to_tensor`` using ``data``'s dtype.
    """
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    return tensor_setitem_by_tensor_with_tensor(data, index, value_tensor)
|
|
1006
|
+
|
|
1007
|
+
|
|
1008
|
+
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
    """Set items selected by a tuple index to a number cast to ``data``'s dtype."""
    casted = F.cast(value, F.dtype(data))
    return tensor_setitem_by_tuple_with_tensor(data, tuple_index, casted)
|
|
1012
|
+
|
|
1013
|
+
|
|
1014
|
+
def tensor_setitem_by_list(data, index, value):
    """Set items selected by a list index.

    A list that ``_convert_list_index_to_tensor`` can turn into a tensor uses
    the tensor-index path. Otherwise the list is treated as a tuple index.
    """
    index_tensor = _convert_list_index_to_tensor(index)
    if index_tensor is None:
        return tensor_setitem_by_tuple_with_tensor(data, tuple(index), value)
    return tensor_setitem_by_tensor_with_tensor(data, index_tensor, value)
|
|
1021
|
+
|
|
1022
|
+
|
|
1023
|
+
|
|
1024
|
+
class _PreSetitemByTuple(base.PreSetitemByTuple_):
    """
    Preprocess a tuple index for tensor setitem.

    Called as ``special_index, tuple_index, new_value_shape, idx_advanced,
    broadcast_data_shape = _pre_setitem_by_tuple(data, tuple_index, value)``.
    That caller handles ``special_index`` as follows:

    * ``0``: returns ``data`` unchanged;
    * ``1``: assigns through ``data[True] = value``.

    ``__call__`` below is only a placeholder (``pass``). The logic is
    presumably provided by ``base.PreSetitemByTuple_`` — TODO confirm.

    Args:
        data (Tensor): The tensor to be updated.
        tuple_index (tuple): Index of tensor.
        value (Tensor): The value to write.
    """

    def __init__(self, name):
        """Initialize _PreSetitemByTuple."""
        base.PreSetitemByTuple_.__init__(self, name)

    def __call__(self, *args):
        pass


_pre_setitem_by_tuple = _PreSetitemByTuple('pre_setitem_by_tuple')
|
|
1045
|
+
|
|
1046
|
+
|
|
1047
|
+
class _HandleBoolTensor(base.HandleBoolTensor_):
    """
    Preprocess bool tensors contained in a tuple index.

    Called as ``tuple_index, zero_index, non_zero_shapes =
    _handle_bool_tensor(tuple_index)``. Callers treat any shape in
    ``non_zero_shapes`` that contains 0 as an empty selection. ``__call__``
    below is only a placeholder (``pass``). The logic is presumably provided
    by ``base.HandleBoolTensor_`` — TODO confirm.

    Args:
        tuple_index (tuple): The tuple index to process.

    Returns:
        tuple: ``(tuple_index, zero_index, non_zero_shapes)``.
    """

    def __init__(self, name):
        """Initialize _HandleBoolTensor."""
        base.HandleBoolTensor_.__init__(self, name)

    def __call__(self, *args):
        pass


_handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
|
|
1068
|
+
|
|
1069
|
+
|
|
1070
|
+
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
    """Assigns the tensor by tuple with tensor value.

    Paths:
        1. ``copy_slice`` fast path, taken when
           ``const_utils.use_copy_slice(tuple_index)`` holds and the backend
           is not Ascend. It is built from ``tuple_index[0]`` (presumably an
           int row) and ``tuple_index[1]``, which is normalized as a slice over
           ``data.shape[1]``.
        2. Bool tensors are handled by ``_handle_bool_tensor``. An empty
           selection returns ``data`` unchanged.
        3. ``_pre_setitem_by_tuple`` classifies the rest:

           * ``special_index == 0`` returns ``data``;
           * ``special_index == 1`` (or an empty remaining index) assigns via
             ``data[True] = value``;
           * otherwise values are scattered at the indices from
             ``_tensor_index_setitem``.

    A Parameter ``data`` is updated in place with ``F.scatter_nd_update``
    and returned; otherwise a new tensor from ``F.tensor_scatter_update`` is
    returned.
    """
    if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
        dim1_start, dim1_stop, _ = const_utils.normalize_slice(
            tuple_index[1], data.shape[1])
        if isinstance(dim1_start, Tensor):
            dim1_start = int(dim1_start)
        if isinstance(dim1_stop, Tensor):
            dim1_stop = int(dim1_stop)
        # Empty slice on dim 1: nothing to write.
        if dim1_stop - dim1_start <= 0:
            return data
        # Wrap a negative row index; no bounds check is done here.
        dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
        start = (dim0_start, dim1_start)
        stop = (dim0_start + 1, dim1_stop)
        step = (1, 1)
        value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
        value = F.broadcast_to(value, value_shape)
        return copy_slice(data, value.astype(data.dtype), start, stop, step)
    tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)

    # A bool tensor that selects nothing makes the whole assignment a no-op.
    for non_zero_shape in non_zero_shapes:
        if 0 in non_zero_shape:
            return data
    value = value.astype(data.dtype)
    special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
        = _pre_setitem_by_tuple(data, tuple_index, value)
    if special_index == 0:
        return data
    value = F.reshape(value, new_value_shape)
    if not tuple_index or special_index == 1:
        data[True] = value
        return data

    # The broadcast shape may arrive as a Tensor or a tuple; treat Tensor([0]) or () as empty.
    empty_broadcast_data_shape = False
    if isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
        empty_broadcast_data_shape = True
    if isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
        empty_broadcast_data_shape = True
    indices = _tensor_index_setitem(
        data, tuple_index, value, idx_advanced, empty_broadcast_data_shape)

    updates = _generate_updates_from_tensor(
        data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
    if is_parameter(data):
        F.scatter_nd_update(data, indices, updates)
        return data
    return F.tensor_scatter_update(data, indices, updates)
|
|
1117
|
+
|
|
1118
|
+
def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
    """Assigns the tensor by tuple with tensor value.

    The steps are:

    1. Replace an Ellipsis in ``tuple_index`` with slices.
    2. Take the same ``copy_slice`` fast path as
       ``tensor_setitem_by_tuple_with_tensor`` when applicable.
    3. Otherwise strip expanded dims (``remove_expanded_dims``) and handle a
       single remaining index via ``data[tuple_index[0]] = value``.
    4. Otherwise build scatter indices (all-tensor or mixed tuple) and apply
       ``F.tensor_scatter_update``.

    Returns ``data`` unchanged when the index selects nothing.
    """
    op_name = const_utils.TENSOR_SETITEM
    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)

    if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
        dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
        if isinstance(dim1_start, Tensor):
            dim1_start = int(dim1_start)
        if isinstance(dim1_stop, Tensor):
            dim1_stop = int(dim1_stop)
        # Empty slice on dim 1: nothing to write.
        if dim1_stop - dim1_start <= 0:
            return data
        # Wrap a negative row index; no bounds check is done here.
        dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
        start = (dim0_start, dim1_start)
        stop = (dim0_start + 1, dim1_stop)
        step = (1, 1)
        value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
        value = F.broadcast_to(value, value_shape)
        return copy_slice(data, value.astype(data.dtype), start, stop, step)
    tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)

    # remove_expanded_dims signals an empty selection with False.
    if tuple_index is False:
        return data
    if len(tuple_index) == 1:
        data[tuple_index[0]] = value
        return data
    indexes_types = hyper_map(toptypeof, tuple_index)
    contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)

    if contain_type == const_utils.ALL_TENSOR:
        indices = _generate_indices_from_tuple_of_tensor(tuple_index, op_name)
    else:
        indices = _generate_indices_from_tuple(data, tuple_index, op_name, idx_advanced)
        # False means a statically empty slice: nothing to update.
        if indices is False:
            return data
    updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
    return F.tensor_scatter_update(data, indices, updates)
|
|
1156
|
+
|
|
1157
|
+
|
|
1158
|
+
def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
    """Set items selected by a tuple index to a list/tuple value.

    The sequence is converted with ``sequence_to_tensor`` using ``data``'s dtype.
    """
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value_tensor)
|
|
1161
|
+
|
|
1162
|
+
|
|
1163
|
+
def tensor_setitem_by_number_with_number(data, index, value):
    """Assigns the tensor by number with number value.

    Writes ``value`` (cast to ``data``'s dtype and broadcast to
    ``(1,) + data.shape[1:]``) into row ``index`` of axis 0. Negative indices
    count from the end, as in Python.

    Args:
        data (Tensor): Tensor to update.
        index (int): Row index along axis 0, valid in ``[-dim_size, dim_size)``.
        value: Number or tensor to write.

    Returns:
        Tensor: ``data`` itself for a Parameter (updated with
        ``F.scatter_nd_update``), otherwise the result of ``F.tensor_scatter_update``.

    Raises:
        IndexError: if ``index`` is outside ``[-dim_size, dim_size)``.
    """
    data_shape = F.shape(data)
    dim_size = data_shape[0]
    # Check bounds on the caller's index *before* wrapping negatives. Checking
    # after the wrap would let e.g. -5 (dim 3) become -2 and slip through.
    if index < -dim_size or index >= dim_size:
        raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
    if index < 0:
        index += dim_size
    index = F.cast(index, mstype.int64)
    index = F.reshape(index, (1, 1))

    updates = F.cast(value, data.dtype)
    updates_shape = (1,) + data_shape[1:]
    updates = ops.broadcast_to(updates, updates_shape)

    if is_parameter(data):
        F.scatter_nd_update(data, index, updates)
        return data
    return F.tensor_scatter_update(data, index, updates)
|
|
1182
|
+
|
|
1183
|
+
|
|
1184
|
+
def tensor_setitem_by_number_with_sequence(data, index, value):
    """Assigns a list/tuple value to the tensor indexed by a scalar number.

    The sequence is converted to a tensor of ``data``'s dtype and handed to the
    number-index/tensor-value path.
    """
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    return tensor_setitem_by_number_with_tensor(data, index, value_tensor)
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
def tensor_setitem_by_number_with_tensor(data, index, value):
    """Assigns a tensor value to the tensor indexed by a scalar number.

    Delegates unchanged to ``tensor_setitem_by_number_with_number``, which casts
    and broadcasts its value with tensor ops and therefore accepts tensors too.
    """
    result = tensor_setitem_by_number_with_number(data, index, value)
    return result
|
|
1192
|
+
|
|
1193
|
+
|
|
1194
|
+
def tensor_setitem_by_ellipsis_with_number(data, value):
    """Assigns the tensor by ellipsis with number value.

    ``data[...] = value`` for a scalar: returns a new tensor with ``data``'s shape
    and dtype, filled with ``value``.
    """
    target_shape = F.shape(data)
    return F.fill(F.dtype(data), target_shape, value)
|
|
1199
|
+
|
|
1200
|
+
|
|
1201
|
+
def tensor_setitem_by_ellipsis_with_tensor(data, value):
    """Assigns the tensor by ellipsis with tensor value.

    ``data[...] = value``: ``value`` is cast to ``data``'s dtype and broadcast to
    ``data``'s shape. If ``value`` has more dimensions than ``data``, it is first
    reshaped to ``data``'s shape (so its element count must already match —
    presumably to drop leading size-1 axes).
    """
    target_shape = F.shape(data)
    cast_value = value.astype(F.dtype(data))
    cast_shape = F.shape(cast_value)
    source_shape = target_shape if len(cast_shape) > len(target_shape) else cast_shape
    return F.broadcast_to(F.reshape(cast_value, source_shape), target_shape)
|
|
1216
|
+
|
|
1217
|
+
|
|
1218
|
+
def tensor_setitem_by_ellipsis_with_sequence(data, value):
    """Assigns a list/tuple value to the tensor by ellipsis.

    The sequence becomes a tensor of ``data``'s dtype, then the ellipsis/tensor
    path broadcasts it over ``data``.
    """
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    return tensor_setitem_by_ellipsis_with_tensor(data, value_tensor)
|
|
1222
|
+
|
|
1223
|
+
|
|
1224
|
+
def tensor_setitem_by_bool(data, index, value):
    """Assigns a value to the tensor indexed by a boolean.

    ``data[True] = value`` broadcasts ``value`` (after casting to ``data``'s dtype)
    over the whole tensor. ``data[False] = value`` selects nothing and returns
    ``data`` unchanged. The value is still converted first, so conversion errors
    surface in both cases.
    """
    target_dtype = F.dtype(data)
    if isinstance(value, (list, tuple)):
        value = sequence_to_tensor(value, target_dtype)
    else:
        value = F.cast(value, target_dtype)

    if not index:
        return data

    target_shape = F.shape(data)
    value_shape = F.shape(value)
    # A value with more dims than data is reshaped to data's shape before broadcasting.
    source_shape = target_shape if len(value_shape) > len(target_shape) else value_shape
    return F.broadcast_to(F.reshape(value, source_shape), target_shape)
|
|
1244
|
+
|
|
1245
|
+
|
|
1246
|
+
def tensor_in_sequence(x, y):
    """Determines whether the sequence ``y`` contains a tensor equal to ``x``.

    Only elements of ``y`` that are Tensors with the same shape and dtype as ``x``
    are compared; such an element matches when all of its entries equal ``x``'s.

    Returns:
        A boolean tensor, starting from ``const_utils.scalar_to_tensor(False)`` and
        OR-ed with each candidate's match result.
    """
    contains = const_utils.scalar_to_tensor(False)
    for candidate in y:
        if isinstance(candidate, Tensor) and x.shape == candidate.shape and x.dtype == candidate.dtype:
            all_equal = F.equal(x, candidate).all()
            contains = F.logical_or(all_equal, contains)
    return contains
|
|
1253
|
+
|
|
1254
|
+
|
|
1255
|
+
@_primexpr
def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
    """Parses a boolean tensor index into per-dimension integer index tensors.

    The coordinates of the True entries (``index_out.nonzero()``, indexed as a 2-D
    array of shape ``(num_true, rank)``) are split column by column. Each column is
    appended to ``indices_out`` and its shape is appended to ``shapes``; ``cur_dim``
    advances by one per column.

    Note: ``shapes`` is mutated in place and also returned.

    Returns:
        ``(indices_out, shapes, cur_dim)``; ``indices_out`` is ``None`` when the mask
        has no True entry.
    """
    true_coords = index_out.nonzero()
    if true_coords.shape[0] == 0:
        # Nothing selected: signal the caller with None.
        return None, shapes, cur_dim
    for axis in range(true_coords.shape[1]):
        axis_indices = true_coords[:, axis]
        indices_out += (axis_indices,)
        shapes.append(F.shape(axis_indices))
        cur_dim += 1
    return indices_out, shapes, cur_dim
|
|
1267
|
+
|
|
1268
|
+
|
|
1269
|
+
def remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim):
    """Parses one tensor element of a tuple index.

    Boolean tensors are expanded by ``remove_expanded_dims_parse_bool_tensor_index``.
    Any other tensor is appended to ``indices_out`` as-is, its shape is appended to
    ``shapes`` (mutated in place), and ``cur_dim`` advances by one.

    Returns:
        ``(indices_out, shapes, cur_dim)``.
    """
    if index_out.dtype != mstype.bool_:
        updated_indices = indices_out + (index_out,)
        shapes.append(F.shape(index_out))
        return updated_indices, shapes, cur_dim + 1
    return remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim)
|
|
1277
|
+
|
|
1278
|
+
|
|
1279
|
+
def remove_expanded_dims(tuple_index, data_shape, value):
    """Removes expanded dimensions in tuple_index and value.

    Walks ``tuple_index`` once, normalizing each element with ``format_index`` and
    classifying it as ``None`` (new axis), slice, or advanced index (Tensor or bool).
    Slices and tensor indices are collected into the returned index tuple. ``None``
    and Python bools are not collected. ``value`` is then reshaped through
    ``const_utils.filter_expanded_dims`` using the per-position ``not_expanded_dim``
    flags.

    Args:
        tuple_index (tuple): The index elements of a setitem.
        data_shape (tuple): Shape of the tensor being indexed.
        value (Tensor): The value being assigned.

    Returns:
        tuple: ``(indices_out, value, idx_advanced)``.
        ``indices_out`` is ``False`` when the index selects no elements. That happens
        for a bool tensor with no True entry (early exit), a literal ``False``, or a
        slice flagged by ``parse_check_slice_index`` (presumably an empty slice).
        It is ``(True,)`` when no slice or tensor index was collected.
        ``idx_advanced`` is the value produced by ``const_utils.rem_not_expanded_dims``
        on the normal path, and ``0`` on the early-exit path.
    """
    not_expanded_dim = ()
    shapes = []
    has_true = False
    has_false = False
    has_sequence = False
    indices_out = ()  # with dimension expansion indices removed
    idx_tensor = -1  # index of the previous tensor
    idx_advanced = -1  # index of the first advanced index in expanded tensor
    cur_dim = 0  # current dimension of the data to be indexed

    for i, v in enumerate(tuple_index):
        index_out = format_index(v, data_shape, cur_dim)

        if index_out is None:
            # New axis: recorded as an expanded dim, consumes no data dimension.
            not_expanded_dim += (False,)
        elif const_utils.is_slice(index_out):
            indices_out += (index_out,)
            not_expanded_dim += (True,)
            has_false = has_false or parse_check_slice_index(
                index_out, data_shape[cur_dim])
            cur_dim += 1
        elif isinstance(index_out, (Tensor, bool)):  # advanced index
            if idx_advanced == -1:
                idx_advanced = len(not_expanded_dim)
            elif i - idx_tensor > 1:
                # Advanced indices not adjacent in tuple_index: the position resets to 0
                # (presumably mirroring NumPy's rule for separated advanced indices).
                idx_advanced = 0
            idx_tensor = i
            if isinstance(index_out, Tensor):
                indices_out, shapes, cur_dim = \
                    remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim)
                if indices_out is None:
                    # Bool tensor index with no True entry: nothing to assign.
                    return False, value, 0
                if index_out.dtype != mstype.bool_ and F.rank(index_out) > 0:
                    has_sequence = True
            # Only Python bools can be identical to True/False here.
            has_true = has_true or index_out is True
            has_false = has_false or index_out is False
        else:
            const_utils.raise_index_error('invalid index type')

    broadcast_shape = const_utils.generate_broadcast_shape(shapes, const_utils.TENSOR_SETITEM)
    if has_false:
        if F.shape_mul(broadcast_shape) != 1:
            const_utils.raise_index_error('unable to broadcast indices')
        indices_out = False
    else:
        expand_true = has_true and not (has_false or has_sequence)  # whether to expand dimension at True
        tensor_index_ndim = len(broadcast_shape)  # ndim of tensor indices
        rem_ndim = len(data_shape) - cur_dim  # number of remaining dimensions in data not indexed
        not_expanded_dim, idx_advanced = const_utils.rem_not_expanded_dims(idx_advanced, expand_true,
                                                                           tensor_index_ndim,
                                                                           rem_ndim, not_expanded_dim)
        if not indices_out:
            indices_out = (True,)

    value_shape = const_utils.filter_expanded_dims(F.shape(value), not_expanded_dim)
    value = F.reshape(value, value_shape)
    return indices_out, value, idx_advanced
|
|
1338
|
+
|
|
1339
|
+
|
|
1340
|
+
def format_index(idx, data_shape, cur_dim):
    """Converts an advanced index element into a tensor.

    - list/tuple: converted by ``const_utils.sequence_to_index`` against the size of
      dimension ``cur_dim``.
    - int (but not bool): converted to an int64 tensor by ``const_utils.make_tensor``.
    - integer Tensor: negative entries are shifted by the size of dimension ``cur_dim``.
    - bool Tensor: returned unchanged; it is processed in ``remove_expanded_dims()``.
    - anything else (None, slice, bool, ...): returned unchanged.

    ``const_utils.get_index_tensor_dtype`` is always consulted for tensors,
    presumably so that unsupported dtypes are rejected there.
    """
    if isinstance(idx, (tuple, list)):
        return const_utils.sequence_to_index(idx, data_shape[cur_dim])
    if isinstance(idx, int) and not isinstance(idx, bool):
        return const_utils.make_tensor(idx, mstype.int64, None, data_shape[cur_dim])
    if isinstance(idx, Tensor):
        if const_utils.get_index_tensor_dtype(idx.dtype) == const_utils.INT_:
            return F.select(idx < 0, idx + data_shape[cur_dim], idx)
    return idx
|
|
1354
|
+
|
|
1355
|
+
|
|
1356
|
+
@_primexpr
def _check_shape_mul(shape):
    """Rejects shapes whose element count (``F.shape_mul``) is zero.

    Raises:
        ValueError: If ``shape`` describes a zero-size tensor.
    """
    element_count = F.shape_mul(shape)
    if element_count == 0:
        raise ValueError('zero-size tensors are not supported.')
|
|
1360
|
+
|
|
1361
|
+
|
|
1362
|
+
def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, where=True, dtype=None):
    """
    Applies comparison based on cmp_fn and reduction based on reduce_fn.
    If cmp_fn is None, only reduction is performed.

    Args:
        a (Tensor): Input tensor to reduce.
        reduce_fn (Callable): Called as ``reduce_fn(a, axes)`` to perform the reduction.
        cmp_fn (Callable, optional): Called as ``cmp_fn(a, initial)`` before reducing,
            but only when ``initial`` is given.
            NOTE(review): supplying ``initial`` with ``cmp_fn=None`` ends up calling ``None``
            (TypeError). Confirm that callers always pair the two.
        axis: Axis or axes to reduce, validated by ``validator.check_axis_valid``.
        keepdims (bool): Not read in this body. Presumably the caller has already baked
            it into ``reduce_fn`` (TODO confirm).
        initial (scalar or 0-d Tensor, optional): Expanded to ``a``'s shape, combined with
            ``a`` via ``cmp_fn``, and also used to fill positions masked out by ``where``.
        where (optional): Mask of elements to include. A non-Tensor value is converted to a
            bool Tensor. ``None`` or a 0-d True mask disables masking.
        dtype (optional): Output dtype; defaults to ``a``'s dtype.

    Returns:
        Tensor: ``reduce_fn(a, axes)`` cast to ``dtype``.

    Raises:
        ValueError: If ``a`` has zero elements (raised by ``_check_shape_mul``).
        Errors raised through ``const_utils.raise_type_error`` (``initial`` not scalar) and
        ``const_utils.raise_value_error`` (non-trivial ``where`` without ``initial``),
        presumably TypeError / ValueError respectively.
    """

    shape = F.shape(a)
    ndim = F.rank(a)
    if dtype is None:
        dtype = F.dtype(a)
    axes = validator.check_axis_valid(axis, ndim)
    if initial is not None:
        if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or
                not isinstance(initial, (int, float, bool, Tensor))):
            const_utils.raise_type_error('initial must be scalar')

    _check_shape_mul(shape)

    if initial is not None:
        # Expand initial to a's full shape so it can be combined elementwise.
        if isinstance(initial, Tensor):
            initial = F.tile(initial, shape).astype(dtype)
        else:
            initial = F.fill(dtype, shape, initial)
        a = cmp_fn(a, initial)

    if where is not None and not isinstance(where, Tensor):
        where = Tensor(where, dtype=mstype.bool_)

    # A 0-d True mask has an empty shape and is truthy, so it skips masking.
    if where is not None and (where.shape or not where):
        if initial is None:
            const_utils.raise_value_error('initial value must be provided for where masks')
        ndim_orig = F.rank(a)
        # broadcasts input tensors
        shape_out = const_utils.infer_out_shape(F.shape(where), F.shape(a), F.shape(initial))
        # The mask is broadcast as float32 and cast back to bool. Presumably this works
        # around bool support in broadcast_to (TODO confirm).
        where = where.astype(mstype.float32)
        where = F.broadcast_to(where, shape_out)
        where = where.astype(mstype.bool_)
        a = F.broadcast_to(a, shape_out)
        initial = F.broadcast_to(initial, shape_out)
        # Masked-out positions take the initial value.
        a = F.select(where, a, initial)
        # Broadcasting may have added leading dims; remap axes accordingly.
        axes = const_utils.real_axes(ndim_orig, F.rank(a), axes)

    return reduce_fn(a, axes).astype(dtype)
|
|
1405
|
+
|
|
1406
|
+
|
|
1407
|
+
# Expose reduce_ to Tensor methods through the shared operator registry.
tensor_operator_registry.reduce = reduce_
|
|
1408
|
+
|
|
1409
|
+
|
|
1410
|
+
def check_indices(dims, indices, mode, allow_negative_index=True):
    """Checks whether indices are out of bounds and remaps them according to ``mode``.

    Args:
        dims: Size of the indexed axis.
        indices (Tensor): Integer index tensor.
        mode (str): ``'raise'`` is not implemented and always errors. ``'wrap'`` maps
            indices with ``indices - dims * floordiv(indices, dims)``. Any other value
            clips: entries below the lower bound become 0 and entries above ``dims - 1``
            become ``dims - 1``.
        allow_negative_index (bool): Lower bound is ``-dims`` when True (in-range negative
            entries are left as-is), and ``0`` otherwise.

    Returns:
        Tensor with the same shape and dtype as ``indices``.
    """
    idx_shape = F.shape(indices)
    idx_dtype = F.dtype(indices)
    if mode == 'raise':
        const_utils.raise_unimplemented_error('"raise" mode is not implemented')
    if mode == 'wrap':
        dims_tensor = F.fill(idx_dtype, idx_shape, dims)
        multiple = F.tensor_mul(dims_tensor, F.tensor_floordiv(indices, dims_tensor))
        return F.tensor_sub(indices, multiple)

    # Clip mode: the bound tensors are only needed here.
    lower_value = -dims if allow_negative_index else 0
    lower = F.fill(idx_dtype, idx_shape, lower_value)
    upper = F.fill(idx_dtype, idx_shape, dims - 1)
    below = F.tensor_lt(indices, lower)
    above = F.tensor_gt(indices, upper)
    result = F.select(below, F.fill(idx_dtype, idx_shape, 0), indices)
    return F.select(above, upper, result)
|
|
1432
|
+
|
|
1433
|
+
|
|
1434
|
+
# Expose check_indices to Tensor methods through the shared operator registry.
tensor_operator_registry.check_indices = check_indices
|
|
1435
|
+
|
|
1436
|
+
|
|
1437
|
+
def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
    """Converts a slice (given as ``index`` values) into an int64 index tensor.

    The values are reshaped to the layout computed by ``const_utils.compute_slice_shape``
    and then tiled with the multiples from ``const_utils.compute_multiples`` so that
    the result matches ``final_shape``.
    """
    target_rank = len(broadcast_shape)
    layout = const_utils.compute_slice_shape(slice_shapes, target_rank, slice_cnt, fancy_position)
    index_array = const_utils.make_tensor(index, mstype.int64).reshape(layout)
    multiples = const_utils.compute_multiples(layout, final_shape)
    return F.tile(index_array, multiples)
|
|
1444
|
+
|
|
1445
|
+
|
|
1446
|
+
def check_coo_tensor_input_length(coo_tuple):
    """Validates that a COO tensor tuple has exactly three parts.

    Args:
        coo_tuple: Sequence expected to be ``(indices, values, shape)``.

    Returns:
        ``coo_tuple`` unchanged.

    Raises:
        ValueError: If ``coo_tuple`` does not have exactly 3 elements.
    """
    actual_len = len(coo_tuple)
    if actual_len == 3:
        return coo_tuple
    raise ValueError(f"Expect coo_tuple have 3 inputs (indices, values, shape), but got {actual_len}.")
|
|
1452
|
+
|
|
1453
|
+
|
|
1454
|
+
def check_csr_tensor_input_length(csr_tuple):
    """Validates that a CSR tensor tuple has exactly four parts.

    Args:
        csr_tuple: Sequence expected to be ``(indptr, indices, values, shape)``.

    Returns:
        ``csr_tuple`` unchanged.

    Raises:
        ValueError: If ``csr_tuple`` does not have exactly 4 elements.
    """
    actual_len = len(csr_tuple)
    if actual_len == 4:
        return csr_tuple
    raise ValueError(f"Expect csr_tuple have 4 inputs (indptr, indices, values, shape), but got {actual_len}.")
|