mindspore 2.2.14__cp37-cp37m-manylinux1_x86_64.whl → 2.3.0rc2__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -4
- mindspore/_akg/akg/composite/build_module.py +155 -11
- mindspore/_akg/akg/config/repository.json +38 -0
- mindspore/_akg/akg/ms/info_version_adapt.py +29 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -1
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +2 -1
- mindspore/_akg/akg/utils/composite_op_helper.py +4 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +2 -2
- mindspore/_akg/akg/utils/gen_random.py +14 -8
- mindspore/_akg/akg/utils/op_dsl.py +11 -0
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +18 -8
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +78 -0
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +229 -0
- mindspore/_extends/parse/parser.py +155 -59
- mindspore/_extends/parse/resources.py +40 -7
- mindspore/_extends/parse/standard_method.py +127 -206
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _profiler.py} +13 -16
- mindspore/amp.py +24 -18
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +7 -3
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +145 -50
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +9 -6
- mindspore/common/dump.py +5 -4
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +33 -13
- mindspore/common/lazy_inline.py +58 -17
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/parameter.py +24 -4
- mindspore/common/recompute.py +247 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +251 -18
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +391 -465
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/management.py +53 -38
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +176 -55
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +72 -38
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +7 -7
- mindspore/dataset/engine/datasets_vision.py +75 -71
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +5 -5
- mindspore/dataset/vision/transforms.py +2264 -514
- mindspore/dataset/vision/utils.py +40 -9
- mindspore/dataset/vision/validators.py +7 -1
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +34 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/stream.py +339 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +1 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/mindapi/base/format.h +125 -23
- mindspore/include/mindapi/base/types.h +12 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +2044 -154
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +2044 -33
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/build_tbe_kernel.py +529 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/compiler.py +56 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/custom.py +1109 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/get_file_path.py +36 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/tbe_topi.py +556 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6318 -1760
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_add_custom.h +49 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +52 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.cpp +192 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.cpp +274 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libcust_opmaster_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +39 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/x86_64/libcust_opsproto_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/{libmindspore_ascend.so.1 → libmindspore_ascend.so.2} +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +2 -2
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +74 -56
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mint/__init__.py +457 -0
- mindspore/mint/nn/__init__.py +430 -0
- mindspore/mint/nn/functional.py +424 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +186 -0
- mindspore/multiprocessing/__init__.py +72 -0
- mindspore/nn/__init__.py +3 -0
- mindspore/nn/cell.py +131 -174
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
- mindspore/nn/extend/layer/normalization.py +107 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/activation.py +79 -90
- mindspore/nn/layer/basic.py +113 -81
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +71 -71
- mindspore/nn/layer/embedding.py +105 -44
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/normalization.py +52 -66
- mindspore/nn/layer/padding.py +30 -39
- mindspore/nn/layer/pooling.py +13 -9
- mindspore/nn/layer/rnn_cells.py +5 -15
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +43 -64
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +6 -5
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +33 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +7 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +1 -5
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +54 -60
- mindspore/numpy/utils.py +3 -0
- mindspore/ops/__init__.py +5 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +14 -18
- mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +137 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +101 -57
- mindspore/ops/_vmap/vmap_nn_ops.py +230 -97
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +205 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +257 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +171 -0
- mindspore/ops/auto_generate/gen_extend_func.py +404 -0
- mindspore/ops/auto_generate/gen_ops_def.py +5653 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +11623 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +359 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +118 -17
- mindspore/ops/composite/math_ops.py +9 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +168 -602
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +24 -133
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +54 -0
- mindspore/ops/extend/array_func.py +259 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/extend/nn_func.py +384 -0
- mindspore/ops/function/__init__.py +37 -12
- mindspore/ops/function/array_func.py +702 -1867
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +1 -4
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +24 -17
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +639 -2531
- mindspore/ops/function/nn_func.py +1274 -832
- mindspore/ops/function/other_func.py +4 -5
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +84 -71
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +14 -14
- mindspore/ops/functional.py +57 -63
- mindspore/ops/op_info_register.py +16 -43
- mindspore/ops/operations/__init__.py +19 -20
- mindspore/ops/operations/_grad_ops.py +20 -828
- mindspore/ops/operations/_inner_ops.py +180 -288
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/array_ops.py +83 -2697
- mindspore/ops/operations/comm_ops.py +38 -46
- mindspore/ops/operations/custom_ops.py +14 -96
- mindspore/ops/operations/debug_ops.py +100 -31
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +3 -38
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/{rewrite/ast_transformers → ops/operations/manually_defined}/__init__.py +11 -4
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +1716 -0
- mindspore/ops/operations/math_ops.py +581 -4629
- mindspore/ops/operations/nn_ops.py +260 -1941
- mindspore/ops/operations/other_ops.py +50 -42
- mindspore/ops/operations/random_ops.py +3 -52
- mindspore/ops/operations/sparse_ops.py +3 -3
- mindspore/ops/primitive.py +196 -96
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +257 -0
- mindspore/ops_generate/arg_handler.py +171 -0
- mindspore/ops_generate/gen_aclnn_implement.py +266 -0
- mindspore/ops_generate/gen_ops.py +1062 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +939 -0
- mindspore/ops_generate/gen_utils.py +188 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +349 -0
- mindspore/ops_generate/template.py +238 -0
- mindspore/parallel/__init__.py +6 -4
- mindspore/parallel/_auto_parallel_context.py +52 -2
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +29 -13
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +19 -7
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +1 -1
- mindspore/parallel/_utils.py +147 -6
- mindspore/parallel/algo_parameter_config.py +6 -6
- mindspore/parallel/checkpoint_transform.py +180 -24
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +345 -0
- mindspore/parallel/cluster/process_entity/_utils.py +116 -0
- mindspore/parallel/cluster/run.py +139 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +99 -2
- mindspore/profiler/common/util.py +20 -0
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +66 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +77 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +146 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +109 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +80 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +52 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +59 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +25 -277
- mindspore/profiler/parser/ascend_msprof_exporter.py +112 -132
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +294 -133
- mindspore/profiler/parser/base_timeline_generator.py +6 -0
- mindspore/profiler/parser/framework_parser.py +3 -2
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +16 -1
- mindspore/profiler/profiling.py +305 -167
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +121 -35
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +635 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +107 -144
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/scipy/__init__.py +2 -1
- mindspore/scipy/fft.py +133 -0
- mindspore/scipy/linalg.py +140 -55
- mindspore/scipy/ops.py +15 -71
- mindspore/scipy/ops_grad.py +5 -34
- mindspore/scipy/optimize/line_search.py +2 -2
- mindspore/scipy/optimize/minimize.py +1 -1
- mindspore/train/__init__.py +3 -2
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/anf_ir_pb2.py +8 -2
- mindspore/train/callback/_backup_and_restore.py +4 -4
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +39 -13
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +14 -8
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +7 -7
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/data_sink.py +1 -1
- mindspore/train/dataset_helper.py +18 -4
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +89 -15
- mindspore/train/model.py +24 -22
- mindspore/train/serialization.py +257 -133
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +2 -2
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +534 -1066
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/config/super_bar_config.json +0 -544
- mindspore/gen_ops.py +0 -273
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/scipy/linalg.py
CHANGED
|
@@ -14,8 +14,7 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Linear algebra submodule"""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
|
-
from .ops import LU
|
|
18
|
-
from .ops import SolveTriangular
|
|
17
|
+
from .ops import LU, SolveTriangular
|
|
19
18
|
from .utils import _nd_transpose, _value_check, _type_check, _dtype_check, _mstype_check, _square_check, _solve_check
|
|
20
19
|
from .utils_const import _raise_value_error
|
|
21
20
|
from .. import numpy as mnp
|
|
@@ -26,7 +25,74 @@ from ..ops.operations.linalg_ops import Eigh
|
|
|
26
25
|
from ..ops import functional as F
|
|
27
26
|
from ..ops import operations as P
|
|
28
27
|
|
|
29
|
-
__all__ = ['block_diag', 'inv', 'cho_factor', 'cholesky',
|
|
28
|
+
__all__ = ['block_diag', 'inv', 'cho_factor', 'cholesky',
|
|
29
|
+
'cho_solve', 'eigh', 'lu_factor', 'lu', 'solve_triangular']
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True):
|
|
33
|
+
"""
|
|
34
|
+
Solve the linear system :math:`a x = b` for `x`, Assuming `a` is a triangular matrix.
|
|
35
|
+
|
|
36
|
+
Note:
|
|
37
|
+
- `solve_triangular` is currently only used in `mindscience` scientific computing scenarios and
|
|
38
|
+
dose not support other usage scenarios.
|
|
39
|
+
- `solve_triangular` is not supported on Windows platform yet.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
a (Tensor): A triangular matrix of shape :math:`(*, M, M)` where :math:`*` is zero or more batch dimensions.
|
|
43
|
+
b (Tensor): A Tensor of shape :math:`(*, M)` or :math:`(*, M, N)`. Right-hand side matrix in :math:`a x = b`.
|
|
44
|
+
trans (Union[int, str], optional): Type of system to solve. Default: ``0``.
|
|
45
|
+
|
|
46
|
+
======== =========
|
|
47
|
+
trans system
|
|
48
|
+
======== =========
|
|
49
|
+
0 or 'N' a x = b
|
|
50
|
+
1 or 'T' a^T x = b
|
|
51
|
+
2 or 'C' a^H x = b
|
|
52
|
+
======== =========
|
|
53
|
+
|
|
54
|
+
lower (bool, optional): Use only data contained in the lower triangle of `a`. Default: ``False``.
|
|
55
|
+
unit_diagonal (bool, optional): If ``True``, diagonal elements of :math:`a` are assumed to be 1 and
|
|
56
|
+
will not be referenced. Default: ``False``.
|
|
57
|
+
overwrite_b (bool, optional): Not implemented now. Default: ``False``.
|
|
58
|
+
debug (Any, optional): Not implemented now. Default: ``None``.
|
|
59
|
+
check_finite (bool, optional): Not implemented now. Default: ``True``.
|
|
60
|
+
|
|
61
|
+
Returns:
|
|
62
|
+
Tensor of shape :math:`(*, M)` or :math:`(*, M, N)`,
|
|
63
|
+
which is the solution to the system :math:`a x = b`.
|
|
64
|
+
Shape of :math:`x` matches :math:`b`.
|
|
65
|
+
|
|
66
|
+
Raises:
|
|
67
|
+
ValueError: If `a` is less than 2 dimension.
|
|
68
|
+
ValueError: if `a` is not square matrix.
|
|
69
|
+
TypeError: If dtype of `a` and `b` are not the same.
|
|
70
|
+
ValueError: If the shape of `a` and `b` are not matched.
|
|
71
|
+
ValueError: If `trans` is not in set {0, 1, 2, 'N', 'T', 'C'}.
|
|
72
|
+
|
|
73
|
+
Supported Platforms:
|
|
74
|
+
``Ascend`` ``CPU``
|
|
75
|
+
|
|
76
|
+
Examples:
|
|
77
|
+
>>> import numpy as onp
|
|
78
|
+
>>> import mindspore
|
|
79
|
+
>>> from mindspore import Tensor
|
|
80
|
+
>>> from mindspore.scipy.linalg import solve_triangular
|
|
81
|
+
>>> a = Tensor(onp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], onp.float32))
|
|
82
|
+
>>> b = Tensor(onp.array([3, 1, 3, 4], onp.float32))
|
|
83
|
+
>>> x = solve_triangular(a, b, lower=True, unit_diagonal=False, trans='N')
|
|
84
|
+
>>> print(x)
|
|
85
|
+
[ 1. -1. 2. 2.]
|
|
86
|
+
>>> print(a @ x) # Check the result
|
|
87
|
+
[3. 1. 3. 4.]
|
|
88
|
+
"""
|
|
89
|
+
trans_str_to_int = {'N': 0, 'T': 1, 'C': 2}
|
|
90
|
+
if isinstance(trans, str):
|
|
91
|
+
trans = trans_str_to_int.get(trans)
|
|
92
|
+
if trans is None:
|
|
93
|
+
_raise_value_error(
|
|
94
|
+
"For SolveTriangular, Augment[trans] must be one of [1, 2, 3,'N', 'T', 'C'].")
|
|
95
|
+
return ops.auto_generate.solve_triangular(a, b, trans=trans, lower=lower, unit_diagonal=unit_diagonal)
|
|
30
96
|
|
|
31
97
|
|
|
32
98
|
def block_diag(*arrs):
|
|
@@ -47,7 +113,7 @@ def block_diag(*arrs):
|
|
|
47
113
|
|
|
48
114
|
Args:
|
|
49
115
|
arrs (list): up to 2-D Input Tensors.
|
|
50
|
-
|
|
116
|
+
One or more Tensors, the dimension of Tensors should be 0-D, 1-D or 2-D.
|
|
51
117
|
|
|
52
118
|
Returns:
|
|
53
119
|
Tensor with `A`, `B`, `C`, ... on the diagonal which has the same dtype as `A`.
|
|
@@ -102,8 +168,8 @@ def inv(a, overwrite_a=False, check_finite=True):
|
|
|
102
168
|
|
|
103
169
|
Note:
|
|
104
170
|
- `inv` is not supported on Windows platform yet.
|
|
105
|
-
- Only
|
|
106
|
-
|
|
171
|
+
- Only float32, float64, int32, int64 are supported Tensor dtypes.
|
|
172
|
+
- If Tensor with dtype int32 or int64 is passed, it will be cast to mstype.float64.
|
|
107
173
|
|
|
108
174
|
Args:
|
|
109
175
|
a (Tensor): Square matrix to be inverted.
|
|
@@ -152,14 +218,21 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
|
|
|
152
218
|
"""
|
|
153
219
|
Compute the cholesky decomposition of a matrix, to use in :func:`mindspore.scipy.linalg.cho_solve`.
|
|
154
220
|
|
|
155
|
-
Returns a matrix
|
|
156
|
-
|
|
221
|
+
Returns the cholesky decomposition of a Hermitian positive-definite matrix A. Base on the value of `lower`,
|
|
222
|
+
perform the following decomposition:
|
|
223
|
+
|
|
224
|
+
- when `lower` is True: :math:`A = L L^*`
|
|
225
|
+
- when `lower` is False: :math:`A = U^* U`
|
|
226
|
+
|
|
227
|
+
:math:`L^*` is a conjugate transpose matrix of :math:`L`.
|
|
228
|
+
:math:`U^*` is a conjugate transpose matrix of :math:`U`.
|
|
229
|
+
|
|
157
230
|
The return value can be directly used as the first parameter to :func:`mindspore.scipy.linalg.cho_solve`.
|
|
158
231
|
|
|
159
232
|
Note:
|
|
160
233
|
- `cho_factor` is not supported on Windows platform yet.
|
|
161
|
-
- Only
|
|
162
|
-
|
|
234
|
+
- Only float32, float64, int32, int64 are supported Tensor dtypes.
|
|
235
|
+
- If Tensor with dtype int32 or int64 is passed, it will be cast to mstype.float64.
|
|
163
236
|
|
|
164
237
|
.. warning::
|
|
165
238
|
The returned matrix also contains random data in the entries not
|
|
@@ -167,7 +240,7 @@ def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
|
|
|
167
240
|
entries, use the function `cholesky` instead.
|
|
168
241
|
|
|
169
242
|
Args:
|
|
170
|
-
a (Tensor): square Matrix of (M,
|
|
243
|
+
a (Tensor): square Matrix of :math:`(M,M)` to be decomposed.
|
|
171
244
|
lower (bool, optional): Whether to compute the upper or lower triangular cholesky factorization.
|
|
172
245
|
Default: ``False`` .
|
|
173
246
|
overwrite_a(bool, optional): Whether to overwrite data in a (may improve performance). Default: ``False`` .
|
|
@@ -221,16 +294,22 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
|
|
|
221
294
|
"""
|
|
222
295
|
Compute the cholesky decomposition of a matrix.
|
|
223
296
|
|
|
224
|
-
Returns the cholesky decomposition
|
|
225
|
-
|
|
297
|
+
Returns the cholesky decomposition of a Hermitian positive-definite matrix A. Base on the value of `lower`,
|
|
298
|
+
perform the following decomposition:
|
|
299
|
+
|
|
300
|
+
- when `lower` is True: :math:`A = L L^*`
|
|
301
|
+
- when `lower` is False: :math:`A = U^* U`
|
|
302
|
+
|
|
303
|
+
:math:`L^*` is a conjugate transpose matrix of L.
|
|
304
|
+
:math:`U^*` is a conjugate transpose matrix of U.
|
|
226
305
|
|
|
227
306
|
Note:
|
|
228
307
|
- `cholesky` is not supported on Windows platform yet.
|
|
229
|
-
- Only
|
|
230
|
-
|
|
308
|
+
- Only float32, float64, int32, int64 are supported Tensor dtypes.
|
|
309
|
+
- If Tensor with dtype int32 or int64 is passed, it will be cast to mstype.float64.
|
|
231
310
|
|
|
232
311
|
Args:
|
|
233
|
-
a (Tensor): square Matrix of (M, M) to be decomposed.
|
|
312
|
+
a (Tensor): square Matrix of :math:`(M, M)` to be decomposed.
|
|
234
313
|
lower (bool, optional): Whether to compute the upper- or lower-triangular cholesky
|
|
235
314
|
factorization. Default: ``False`` .
|
|
236
315
|
overwrite_a (bool, optional): Whether to overwrite data in `a` (may improve performance). Default: ``False`` .
|
|
@@ -278,15 +357,15 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
|
|
|
278
357
|
|
|
279
358
|
def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
|
|
280
359
|
"""
|
|
281
|
-
Given the cholesky factorization of
|
|
360
|
+
Given the cholesky factorization of :math:`A`, solve the linear equation.
|
|
282
361
|
|
|
283
362
|
.. math::
|
|
284
|
-
|
|
363
|
+
A x = b
|
|
285
364
|
|
|
286
365
|
Note:
|
|
287
366
|
- `cho_solve` is not supported on Windows platform yet.
|
|
288
|
-
- Only
|
|
289
|
-
|
|
367
|
+
- Only float32, float64, int32, int64 are supported Tensor dtypes.
|
|
368
|
+
- If Tensor with dtype int32 or int64 is passed, it will be cast to mstype.float64.
|
|
290
369
|
|
|
291
370
|
Args:
|
|
292
371
|
c_and_lower ((Tensor, bool)): cholesky factorization of :math:`a`,
|
|
@@ -299,7 +378,7 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
|
|
|
299
378
|
(crashes, non-termination) if the inputs do contain infinities or NaNs. Default: ``True``.
|
|
300
379
|
|
|
301
380
|
Returns:
|
|
302
|
-
Tensor, the solution to the system
|
|
381
|
+
Tensor, the solution to the system :math:`A x = b`.
|
|
303
382
|
|
|
304
383
|
Supported Platforms:
|
|
305
384
|
``GPU`` ``CPU``
|
|
@@ -347,7 +426,9 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
|
|
347
426
|
|
|
348
427
|
Find eigenvalues Tensor `w` and optionally eigenvectors Tensor `v` of Tensor `a`,
|
|
349
428
|
where `b` is positive definite such that for every eigenvalue `λ` (i-th entry of w) and
|
|
350
|
-
its eigenvector `vi` (i-th column of `v`) satisfies
|
|
429
|
+
its eigenvector `vi` (i-th column of `v`) satisfies:
|
|
430
|
+
|
|
431
|
+
.. code-block::
|
|
351
432
|
|
|
352
433
|
a @ vi = λ * b @ vi
|
|
353
434
|
vi.conj().T @ a @ vi = λ
|
|
@@ -357,8 +438,8 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
|
|
357
438
|
|
|
358
439
|
Note:
|
|
359
440
|
- `eigh` is not supported on Windows platform yet.
|
|
360
|
-
- Only
|
|
361
|
-
|
|
441
|
+
- Only float32, float64, int32, int64 are supported Tensor dtypes.
|
|
442
|
+
- If Tensor with dtype int32 or int64 is passed, it will be cast to mstype.float64.
|
|
362
443
|
|
|
363
444
|
Args:
|
|
364
445
|
a (Tensor): A :math:`(M, M)` complex Hermitian or real symmetric matrix whose eigenvalues and
|
|
@@ -369,25 +450,27 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
|
|
369
450
|
triangle of `a` and, if applicable, `b`. Default: ``True``.
|
|
370
451
|
eigvals_only (bool, optional): Whether to calculate only eigenvalues and no eigenvectors.
|
|
371
452
|
Default: ``False`` .
|
|
453
|
+
overwrite_a (bool, optional): Whether to overwrite data in `a` (may improve performance). Default: ``False`` .
|
|
454
|
+
overwrite_b (bool, optional): Whether to overwrite data in `b` (may improve performance). Default: ``False`` .
|
|
455
|
+
turbo (bool, optional): use divide and conquer algorithm (faster but expensive in memory, only
|
|
456
|
+
for generalized eigenvalue problem and if full set of eigenvalues are requested.).
|
|
457
|
+
Has no significant effect if eigenvectors are not requested. Default: ``True`` .
|
|
458
|
+
eigvals (tuple, optional): Indexes of the smallest and largest (in ascending order) eigenvalues
|
|
459
|
+
and corresponding eigenvectors to be returned: :math:`0 <= lo <= hi <= M-1`. If omitted, all eigenvalues
|
|
460
|
+
and eigenvectors are returned. Default: ``None`` .
|
|
372
461
|
type (int, optional): For the generalized problems, this keyword specifies the problem type
|
|
373
|
-
to be solved for `w` and `v` (only takes 1, 2, 3 as possible inputs)
|
|
462
|
+
to be solved for `w` and `v` (only takes 1, 2, 3 as possible inputs):
|
|
463
|
+
|
|
464
|
+
.. code-block::
|
|
374
465
|
|
|
375
466
|
1 => a @ v = w @ b @ v
|
|
376
467
|
2 => a @ b @ v = w @ v
|
|
377
468
|
3 => b @ a @ v = w @ v
|
|
378
469
|
|
|
379
470
|
This keyword is ignored for standard problems. Default: ``1`` .
|
|
380
|
-
overwrite_a (bool, optional): Whether to overwrite data in `a` (may improve performance). Default: ``False`` .
|
|
381
|
-
overwrite_b (bool, optional): Whether to overwrite data in `b` (may improve performance). Default: ``False`` .
|
|
382
471
|
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
|
|
383
472
|
Disabling may give a performance gain, but may result in problems (crashes, non-termination)
|
|
384
473
|
if the inputs do contain infinities or NaNs. Default: ``True`` .
|
|
385
|
-
turbo (bool, optional): use divide and conquer algorithm (faster but expensive in memory, only
|
|
386
|
-
for generalized eigenvalue problem and if full set of eigenvalues are requested.).
|
|
387
|
-
Has no significant effect if eigenvectors are not requested. Default: ``True`` .
|
|
388
|
-
eigvals (tuple, optional): Indexes of the smallest and largest (in ascending order) eigenvalues
|
|
389
|
-
and corresponding eigenvectors to be returned: :math:`0 <= lo <= hi <= M-1`. If omitted, all eigenvalues
|
|
390
|
-
and eigenvectors are returned. Default: ``None`` .
|
|
391
474
|
|
|
392
475
|
Returns:
|
|
393
476
|
- Tensor with shape :math:`(N,)`, the :math:`N (1<=N<=M)` selected eigenvalues, in ascending order,
|
|
@@ -406,7 +489,7 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
|
|
|
406
489
|
TypeError: If `overwrite_b` is not bool.
|
|
407
490
|
TypeError: If `turbo` is not bool.
|
|
408
491
|
TypeError: If `check_finite` is not bool.
|
|
409
|
-
ValueError: If `a` is not square matrix.
|
|
492
|
+
ValueError: If `a` is not 2D square matrix.
|
|
410
493
|
ValueError: If `b` is not None.
|
|
411
494
|
ValueError: If `eigvals` is not None.
|
|
412
495
|
|
|
@@ -473,19 +556,19 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
|
|
|
473
556
|
The decomposition is:
|
|
474
557
|
|
|
475
558
|
.. math::
|
|
476
|
-
a =
|
|
559
|
+
a = P L U
|
|
477
560
|
|
|
478
|
-
where :math:`
|
|
479
|
-
and :math:`
|
|
561
|
+
where :math:`P` is a permutation matrix, :math:`L` lower triangular with unit diagonal elements,
|
|
562
|
+
and :math:`U` upper triangular.
|
|
480
563
|
|
|
481
564
|
Note:
|
|
482
565
|
- `lu_factor` is not supported on Windows platform yet.
|
|
483
|
-
- Only
|
|
484
|
-
|
|
566
|
+
- Only float32, float64, int32, int64 are supported Tensor dtypes.
|
|
567
|
+
- If Tensor with dtype int32 or int64 is passed, it will be cast to mstype.float64.
|
|
485
568
|
|
|
486
569
|
Args:
|
|
487
|
-
a (Tensor): square matrix of :math:`(M, M)` to decompose. Note that if the input tensor is not a
|
|
488
|
-
then it will be cast to
|
|
570
|
+
a (Tensor): square matrix of :math:`(M, M)` to decompose. Note that if the input tensor is not a float,
|
|
571
|
+
then it will be cast to mstype.float32.
|
|
489
572
|
overwrite_a (bool, optional): Whether to overwrite data in :math:`a` (may increase performance).
|
|
490
573
|
Default: ``False`` .
|
|
491
574
|
check_finite (bool, optional): Whether to check that the input matrix contains only finite numbers.
|
|
@@ -493,14 +576,14 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
|
|
|
493
576
|
(crashes, non-termination) if the inputs do contain infinities or NaNs. Default: ``True`` .
|
|
494
577
|
|
|
495
578
|
Returns:
|
|
496
|
-
- Tensor, a square matrix of :math:`(
|
|
579
|
+
- Tensor, a square matrix of :math:`(M, M)` containing `U` in its upper triangle, and `L` in its lower triangle.
|
|
497
580
|
The unit diagonal elements of `L` are not stored.
|
|
498
581
|
|
|
499
|
-
- Tensor, :math:`(
|
|
582
|
+
- Tensor, :math:`(M,)` pivot indices representing the permutation matrix `P`:
|
|
500
583
|
the i-th element value j in the indices indicates that row i of matrix was interchanged with row j.
|
|
501
584
|
|
|
502
585
|
Raises:
|
|
503
|
-
ValueError: If :math:`a` is not square.
|
|
586
|
+
ValueError: If :math:`a` is not 2D square.
|
|
504
587
|
|
|
505
588
|
Supported Platforms:
|
|
506
589
|
``GPU`` ``CPU``
|
|
@@ -540,19 +623,19 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
|
|
|
540
623
|
The decomposition is:
|
|
541
624
|
|
|
542
625
|
.. math::
|
|
543
|
-
|
|
626
|
+
A = P L U
|
|
544
627
|
|
|
545
628
|
where :math:`P` is a permutation matrix, :math:`L` lower triangular with unit
|
|
546
629
|
diagonal elements, and :math:`U` upper triangular.
|
|
547
630
|
|
|
548
631
|
Note:
|
|
549
632
|
- `lu` is not supported on Windows platform yet.
|
|
550
|
-
- Only
|
|
551
|
-
|
|
633
|
+
- Only float32, float64, int32, int64 are supported Tensor dtypes.
|
|
634
|
+
- If Tensor with dtype int32 or int64 is passed, it will be cast to mstype.float64.
|
|
552
635
|
|
|
553
636
|
Args:
|
|
554
|
-
a (Tensor): a :math:`(M, N)` matrix to decompose. Note that if the input tensor is not a
|
|
555
|
-
then it will be cast to
|
|
637
|
+
a (Tensor): a :math:`(M, N)` matrix to decompose. Note that if the input tensor is not a float,
|
|
638
|
+
then it will be cast to mstype.float32.
|
|
556
639
|
permute_l (bool, optional): Perform the multiplication :math:`P L` (Default: do not permute).
|
|
557
640
|
Default: ``False`` .
|
|
558
641
|
overwrite_a (bool, optional): Whether to overwrite data in :math:`a` (may improve performance).
|
|
@@ -624,12 +707,12 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
|
|
|
624
707
|
|
|
625
708
|
|
|
626
709
|
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
|
|
627
|
-
"""Solve an equation system,
|
|
710
|
+
"""Solve an equation system, A x = B, given the LU factorization of A
|
|
628
711
|
|
|
629
712
|
Note:
|
|
630
713
|
- `lu_solve` is not supported on Windows platform yet.
|
|
631
|
-
- Only
|
|
632
|
-
|
|
714
|
+
- Only float32, float64, int32, int64 are supported Tensor dtypes.
|
|
715
|
+
- If Tensor with dtype int32 or int64 is passed, it will be cast to mstype.float64.
|
|
633
716
|
|
|
634
717
|
Args:
|
|
635
718
|
lu_and_piv (Tensor, Tensor): Factorization of the coefficient matrix a, as given by lu_factor
|
|
@@ -717,7 +800,9 @@ def det(a, overwrite_a=False, check_finite=True):
|
|
|
717
800
|
The determinant of a square matrix is a value derived arithmetically
|
|
718
801
|
from the coefficients of the matrix.
|
|
719
802
|
|
|
720
|
-
The determinant for a 3x3 matrix, for example, is computed as follows
|
|
803
|
+
The determinant for a 3x3 matrix, for example, is computed as follows:
|
|
804
|
+
|
|
805
|
+
.. code-block::
|
|
721
806
|
|
|
722
807
|
a b c
|
|
723
808
|
d e f = A
|
|
@@ -727,8 +812,8 @@ def det(a, overwrite_a=False, check_finite=True):
|
|
|
727
812
|
|
|
728
813
|
Note:
|
|
729
814
|
- `det` is not supported on Windows platform yet.
|
|
730
|
-
- Only
|
|
731
|
-
|
|
815
|
+
- Only float32, float64, int32, int64 are supported Tensor dtypes.
|
|
816
|
+
- If Tensor with dtype int32 or int64 is passed, it will be cast to mstype.float64.
|
|
732
817
|
|
|
733
818
|
Args:
|
|
734
819
|
a (Tensor): A square matrix to compute. Note that if the input tensor is not a `float`,
|
mindspore/scipy/ops.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2021-
|
|
1
|
+
# Copyright 2021-2024 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,77 +15,19 @@
|
|
|
15
15
|
"""Operators for scipy submodule"""
|
|
16
16
|
from mindspore import _checkparam as validator
|
|
17
17
|
from ..ops import PrimitiveWithInfer, prim_attr_register, Primitive
|
|
18
|
+
from ..ops.auto_generate import solve_triangular
|
|
18
19
|
from ..common import dtype as mstype
|
|
19
20
|
|
|
20
21
|
|
|
21
|
-
class SolveTriangular(
|
|
22
|
-
"""
|
|
23
|
-
Solve the equation `a x = b` for `x`, assuming a is a triangular matrix.
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
a (Tensor): A triangular matrix of shape :math:`(..., N, N)`.
|
|
27
|
-
b (Tensor): A Tensor of shape :math:`(M,)` or :math:`(..., N, M)`.
|
|
28
|
-
Right-hand side matrix in :math:`a x = b`.
|
|
29
|
-
lower (bool, optional): Use only data contained in the lower triangle of `a`.
|
|
30
|
-
Default is to use upper triangle.
|
|
31
|
-
trans (0, 1, 2, 'N', 'T', 'C', optional):
|
|
32
|
-
Type of system to solve:
|
|
33
|
-
trans: system:
|
|
34
|
-
0 or 'N' a x = b
|
|
35
|
-
1 or 'T' a^T x = b
|
|
36
|
-
2 or 'C' a^H x = b
|
|
37
|
-
unit_diagonal (bool, optional): If True, diagonal elements of :math:`a` are assumed to be 1 and
|
|
38
|
-
will not be referenced.
|
|
39
|
-
overwrite_b (bool, optional): Allow overwriting data in :math:`b` (may enhance performance)
|
|
40
|
-
check_finite (bool, optional): Whether to check that the input matrices contain only finite numbers.
|
|
41
|
-
Disabling may give a performance gain, but may result in problems
|
|
42
|
-
(crashes, non-termination) if the inputs do contain infinities or NaNs.
|
|
43
|
-
|
|
44
|
-
Returns:
|
|
45
|
-
Tensor of shape :math:`(..., M,)` or :math:`(..., M, N)`,
|
|
46
|
-
which is the solution to the system :math:`a x = b`.
|
|
47
|
-
Shape of :math:`x` matches :math:`b`.
|
|
48
|
-
|
|
49
|
-
Raises:
|
|
50
|
-
LinAlgError: If :math:`a` is singular
|
|
51
|
-
|
|
52
|
-
Supported Platforms:
|
|
53
|
-
``GPU`` ``CPU``
|
|
54
|
-
|
|
55
|
-
Examples:
|
|
56
|
-
Solve the lower triangular system :math:`a x = b`, where:
|
|
57
|
-
|
|
58
|
-
[3 0 0 0] [4]
|
|
59
|
-
a = [2 1 0 0] b = [2]
|
|
60
|
-
[1 0 1 0] [4]
|
|
61
|
-
[1 1 1 1] [2]
|
|
62
|
-
|
|
63
|
-
>>> import numpy as onp
|
|
64
|
-
>>> from mindspore import Tensor
|
|
65
|
-
>>> import mindspore.numpy as mnp
|
|
66
|
-
>>> from mindspore.scipy.ops import SolveTriangular
|
|
67
|
-
>>> a = Tensor(onp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], onp.float64))
|
|
68
|
-
>>> b = Tensor(onp.array([4, 2, 4, 2], onp.float64))
|
|
69
|
-
>>> solve_triangular = SolveTriangular(lower=True, unit_diagonal=False, trans='N')
|
|
70
|
-
>>> x = solve_triangular(a, b)
|
|
71
|
-
>>> print(x)
|
|
72
|
-
[ 1.33333333 -0.66666667 2.66666667 -1.33333333]
|
|
73
|
-
>>> print(mnp.dot(a, x)) # Check the result
|
|
74
|
-
[4. 2. 4. 2.]
|
|
75
|
-
"""
|
|
76
|
-
|
|
77
|
-
@prim_attr_register
|
|
22
|
+
class SolveTriangular():
|
|
78
23
|
def __init__(self, lower: bool = False, unit_diagonal: bool = False, trans: str = 'N'):
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
self.unit_diagonal = validator.check_value_type(
|
|
84
|
-
"unit_diagonal", unit_diagonal, [bool], self.name)
|
|
85
|
-
self.trans = validator.check_value_type(
|
|
86
|
-
"trans", trans, [str], self.name)
|
|
24
|
+
self.lower = lower
|
|
25
|
+
self.unit_diagonal = unit_diagonal
|
|
26
|
+
trans_str_to_int = {'N': 0, 'T': 1, 'C': 2}
|
|
27
|
+
self.trans = trans_str_to_int.get(trans)
|
|
87
28
|
|
|
88
|
-
|
|
29
|
+
def __call__(self, a, b):
|
|
30
|
+
return solve_triangular(a, b, self.trans, self.lower, self.unit_diagonal)
|
|
89
31
|
|
|
90
32
|
|
|
91
33
|
class Eig(PrimitiveWithInfer):
|
|
@@ -97,9 +39,8 @@ class Eig(PrimitiveWithInfer):
|
|
|
97
39
|
@prim_attr_register
|
|
98
40
|
def __init__(self, compute_v=True):
|
|
99
41
|
super().__init__(name="Eig")
|
|
100
|
-
self.init_prim_io_names(inputs=['a'], outputs=['w', 'v'])
|
|
101
42
|
self.compute_v = validator.check_value_type("compute_v", compute_v, [bool], self.name)
|
|
102
|
-
self.
|
|
43
|
+
self._set_prim_arg("compute_v", compute_v)
|
|
103
44
|
self.io_table = {
|
|
104
45
|
mstype.TensorType(mstype.float32): mstype.complex64,
|
|
105
46
|
mstype.TensorType(mstype.complex64): mstype.complex64,
|
|
@@ -129,6 +70,9 @@ class Eig(PrimitiveWithInfer):
|
|
|
129
70
|
}
|
|
130
71
|
return output
|
|
131
72
|
|
|
73
|
+
def __call__(self, a):
|
|
74
|
+
return super().__call__(a, self.compute_v)
|
|
75
|
+
|
|
132
76
|
|
|
133
77
|
class LU(PrimitiveWithInfer):
|
|
134
78
|
"""
|
|
@@ -182,7 +126,7 @@ class LinearSumAssignment(Primitive):
|
|
|
182
126
|
- **col_idx** (Tensor) - Column indices of the problem. If `dimension_limit` is given, -1 would be padded at
|
|
183
127
|
the end. The shape is :math:`(N, )` , where :math:`N` is the minimum value of `cost_matrix` dimension.
|
|
184
128
|
|
|
185
|
-
|
|
129
|
+
Raises:
|
|
186
130
|
TypeError: If the data type of `cost_matrix` is not the type in [float16, float32, float64,
|
|
187
131
|
int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool]
|
|
188
132
|
TypeError: If the type of `maximize` is not bool.
|
|
@@ -216,4 +160,4 @@ class LinearSumAssignment(Primitive):
|
|
|
216
160
|
self.init_prim_io_names(inputs=['cost_matrix', 'dimension_limit', 'maximize'], outputs=['row_ind', 'col_ind'])
|
|
217
161
|
|
|
218
162
|
# pylint: disable=C0413,W0611
|
|
219
|
-
from .ops_grad import get_bprpo_eigh
|
|
163
|
+
from .ops_grad import get_bprpo_eigh
|
mindspore/scipy/ops_grad.py
CHANGED
|
@@ -14,15 +14,17 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Grad implementation of operators for scipy submodule"""
|
|
16
16
|
from .. import numpy as mnp
|
|
17
|
-
from .ops import Eig
|
|
17
|
+
from .ops import Eig
|
|
18
18
|
from .utils_const import _raise_type_error
|
|
19
19
|
from .ops_wrapper import matrix_set_diag
|
|
20
20
|
from ..ops import operations as P
|
|
21
21
|
from ..ops import functional as F
|
|
22
22
|
from ..ops.operations.linalg_ops import Eigh
|
|
23
23
|
from ..ops._grad_experimental.grad_base import bprop_getters
|
|
24
|
+
from ..ops.composite import multitype_ops as C
|
|
24
25
|
from ..common import dtype as mstype
|
|
25
26
|
|
|
27
|
+
|
|
26
28
|
_matmul = P.MatMul(False, False)
|
|
27
29
|
_real = P.Real()
|
|
28
30
|
_conj = P.Conj()
|
|
@@ -59,9 +61,7 @@ def _batch_eyes(a):
|
|
|
59
61
|
@bprop_getters.register(Eig)
|
|
60
62
|
def get_bprpo_eig(self):
|
|
61
63
|
"""Grad definition for `Eig` operation."""
|
|
62
|
-
is_compute_v
|
|
63
|
-
|
|
64
|
-
def bprop(a, out, dout):
|
|
64
|
+
def bprop(a, is_compute_v, out, dout):
|
|
65
65
|
w, v, grad_w, grad_v = out[0], out[1], dout[0], dout[1]
|
|
66
66
|
if not is_compute_v:
|
|
67
67
|
gw_vh = F.expand_dims(grad_w, -1) * _adjoint(v)
|
|
@@ -74,7 +74,7 @@ def get_bprpo_eig(self):
|
|
|
74
74
|
f = _compute_f(w)
|
|
75
75
|
grad_a = _diag(grad_w) + f * vh_gv
|
|
76
76
|
grad_a = _matrix_solve(vh, _matmul(grad_a, vh)) # not support
|
|
77
|
-
return (grad_a,)
|
|
77
|
+
return (grad_a, C.zeros_like(is_compute_v))
|
|
78
78
|
|
|
79
79
|
return bprop
|
|
80
80
|
|
|
@@ -113,32 +113,3 @@ def get_bprpo_eigh(self):
|
|
|
113
113
|
return (grad_a,)
|
|
114
114
|
|
|
115
115
|
return bprop
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
@bprop_getters.register(SolveTriangular)
|
|
119
|
-
def get_bprpo_trsm(self):
|
|
120
|
-
"""Grad definition for `SolveTriangular` operation.
|
|
121
|
-
Appendix(see trsm) from Matthias Seeger, et al. 'Auto-Differentiating Linear Algebra', 2017, pg. 28-29
|
|
122
|
-
"""
|
|
123
|
-
is_lower = self.lower
|
|
124
|
-
is_unit_diagonal = self.unit_diagonal
|
|
125
|
-
lower = int(is_lower)
|
|
126
|
-
bp_trans = ("N" if self.trans in ["T", "C"] else "T")
|
|
127
|
-
solve_triangular = SolveTriangular(is_lower, is_unit_diagonal, bp_trans)
|
|
128
|
-
|
|
129
|
-
def bprop(a, b, out, dout):
|
|
130
|
-
row_size = F.shape(a)[-2]
|
|
131
|
-
grad_b = solve_triangular(a, dout)
|
|
132
|
-
grad_b_align = F.reshape(grad_b, (row_size, -1))
|
|
133
|
-
x_align = F.reshape(out, (row_size, -1))
|
|
134
|
-
if bp_trans in ["T", "C"]:
|
|
135
|
-
grad_a = _matmul(grad_b_align, _adjoint(x_align))
|
|
136
|
-
else:
|
|
137
|
-
grad_a = _matmul(x_align, _adjoint(grad_b_align))
|
|
138
|
-
|
|
139
|
-
grad_a = -1 * F.matrix_band_part(grad_a, 0 - lower, lower - 1)
|
|
140
|
-
if is_unit_diagonal:
|
|
141
|
-
grad_a = matrix_set_diag(grad_a, F.fill(grad_a.dtype, (row_size,), 0))
|
|
142
|
-
return grad_a, grad_b
|
|
143
|
-
|
|
144
|
-
return bprop
|
|
@@ -312,10 +312,10 @@ def line_search(f, xk, pk, jac=None, gfk=None, old_fval=None, old_old_fval=None,
|
|
|
312
312
|
Args:
|
|
313
313
|
f (function): function of the form f(x) where x is a flat Tensor and returns a real
|
|
314
314
|
scalar. The function should be composed of operations with vjp defined.
|
|
315
|
-
gf (function): the gradient function at x where x is a flat Tensor and returns a Tensor.
|
|
316
|
-
The function can be None if you want to use automatic credits.
|
|
317
315
|
xk (Tensor): initial guess.
|
|
318
316
|
pk (Tensor): direction to search in. Assumes the direction is a descent direction.
|
|
317
|
+
jac (function): the gradient function at x where x is a flat Tensor and returns a Tensor.
|
|
318
|
+
The function can be None if you want to use automatic credits.
|
|
319
319
|
gfk (Tensor): initial value of value_and_gradient as position. Default: ``None`` .
|
|
320
320
|
old_fval (Tensor): The same as `gfk`. Default: ``None`` .
|
|
321
321
|
old_old_fval (Tensor): unused argument, only for scipy API compliance. Default: ``None`` .
|
|
@@ -73,7 +73,7 @@ def minimize(func, x0, args=(), method=None, jac=None, hess=None, hessp=None, bo
|
|
|
73
73
|
- Gradients of ``func`` are calculated automatically using MindSpore's autodiff
|
|
74
74
|
support when the value of jac is None.
|
|
75
75
|
- The ``method`` argument is required. A exception will be thrown if you don't specify a solver.
|
|
76
|
-
- Various optional arguments `"hess"
|
|
76
|
+
- Various optional arguments `"hess"`, `"hessp"`, `"bounds"`, `"constraints"`, `"tol"`, `"callback"`
|
|
77
77
|
in the SciPy interface have not yet been implemented.
|
|
78
78
|
- Optimization results may differ from SciPy due to differences in the line
|
|
79
79
|
search implementation.
|
mindspore/train/__init__.py
CHANGED
|
@@ -26,7 +26,8 @@ from mindspore.train.amp import build_train_network
|
|
|
26
26
|
from mindspore.train.loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
|
|
27
27
|
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, \
|
|
28
28
|
load, parse_print, build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, \
|
|
29
|
-
async_ckpt_thread_status, restore_group_info_list, convert_model, obfuscate_model, export_split_mindir
|
|
29
|
+
async_ckpt_thread_status, restore_group_info_list, convert_model, obfuscate_model, export_split_mindir, \
|
|
30
|
+
load_checkpoint_async
|
|
30
31
|
from mindspore.train.callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, \
|
|
31
32
|
CheckpointConfig, RunContext, LearningRateScheduler, SummaryLandscape, \
|
|
32
33
|
History, LambdaCallback, ReduceLROnPlateau, EarlyStopping, OnRequestExit, BackupAndRestore
|
|
@@ -39,7 +40,7 @@ __all__ = ["Model", "DatasetHelper", "connect_network_with_dataset", "build_trai
|
|
|
39
40
|
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
|
|
40
41
|
"load_param_into_net", "export", "load", "export_split_mindir", "parse_print", "build_searched_strategy",
|
|
41
42
|
"merge_sliced_parameter", "load_distributed_checkpoint", "async_ckpt_thread_status",
|
|
42
|
-
"restore_group_info_list", "convert_model", "data_sink", "obfuscate_model"]
|
|
43
|
+
"restore_group_info_list", "convert_model", "data_sink", "obfuscate_model", "load_checkpoint_async"]
|
|
43
44
|
__all__.extend(callback.__all__)
|
|
44
45
|
__all__.extend(summary.__all__)
|
|
45
46
|
__all__.extend(train_thor.__all__)
|