mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.11__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -1,181 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""Define the grad rules of linalg related operations."""
|
|
17
|
-
from __future__ import absolute_import
|
|
18
|
-
|
|
19
|
-
import mindspore
|
|
20
|
-
|
|
21
|
-
from mindspore.ops import Tensor
|
|
22
|
-
from mindspore.ops import functional as F
|
|
23
|
-
from mindspore.ops import operations as P
|
|
24
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
25
|
-
from mindspore.ops.operations import math_ops as math
|
|
26
|
-
from mindspore.ops.operations import linalg_ops as linalg
|
|
27
|
-
from mindspore.ops.operations import array_ops as arrays
|
|
28
|
-
from mindspore.ops.primitive import constexpr, _primexpr
|
|
29
|
-
from mindspore.ops._grad_experimental.grad_base import bprop_getters
|
|
30
|
-
|
|
31
|
-
_shape = arrays.Shape()
|
|
32
|
-
|
|
33
|
-
_dtype = arrays.DType()
|
|
34
|
-
_cast = arrays.Cast()
|
|
35
|
-
_transpose = arrays.Transpose()
|
|
36
|
-
|
|
37
|
-
_conj = math.Conj()
|
|
38
|
-
_reciprocal = math.Reciprocal()
|
|
39
|
-
|
|
40
|
-
_k_0 = Tensor(0, mindspore.int32)
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
@_primexpr
def _check_dim(dim):
    """Validate that `dim` is at least 2.

    Raises:
        ValueError: if `dim` is smaller than 2.
    """
    if dim >= 2:
        return
    raise ValueError(f"The dim can not be less than 2, which is {dim}.")
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
@_primexpr
def generate_perm_for_matrix_transpose(input_dim):
    """Return the axis permutation that swaps the last two of `input_dim` axes.

    The leading `input_dim - 2` axes keep their position.
    """
    return (*range(input_dim - 2), input_dim - 1, input_dim - 2)
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def _matrix_transpose(a):
    """Transpose the last two axes of `a`."""
    rank = P.Rank()(a)
    if F.is_sequence_value_unknown(_shape(a)):
        # Shape values are not known here, so the permutation is assembled
        # with graph ops: [0, rank) with its final two entries swapped.
        to_type = P.Cast()
        axes = P.Range()(to_type(0, mindspore.int64), to_type(rank, mindspore.int64), to_type(1, mindspore.int64))
        perm = P.Concat(axis=-1)((axes[:-2], axes[-1:], axes[-2:-1]))
    else:
        # Known shape: validate the rank and build the permutation as a tuple.
        _check_dim(rank)
        perm = generate_perm_for_matrix_transpose(rank)
    return _transpose(a, perm)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def _adjoint(a):
    """Return the conjugate of `a` with its last two axes transposed."""
    conjugated = _conj(a)
    return _matrix_transpose(conjugated)
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
def _safe_reciprocal(x, epsilon=1e-20):
    """Return `x / (x * x + epsilon)`, a reciprocal damped by `epsilon`.

    The `epsilon` term keeps the denominator away from zero when `x` is zero.
    """
    damped_square = x * x + epsilon
    return x * _reciprocal(damped_square)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
@constexpr
def _make_tensor(value, dtype):
    """Wrap `value` in a Tensor of type `dtype`."""
    tensor = Tensor(value, dtype)
    return tensor
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def _matrix_diag(diagonal):
    """Build a diagonal matrix from `diagonal` via MatrixDiagV3.

    The size of the last axis of `diagonal` is passed for both size arguments,
    `_k_0` for the diagonal offset, and a zero of `diagonal`'s dtype as padding.
    """
    diagonal_shape = _shape(diagonal)
    matrix_diag = arrays.MatrixDiagV3()
    if F.is_sequence_value_unknown(diagonal_shape):
        # Shape values unknown: produce the size and padding with Cast ops.
        row = P.Cast()(diagonal_shape[-1], mindspore.int32)
        padding = P.Cast()(0, _dtype(diagonal))
    else:
        # Shape known: size and padding can be built as constant tensors.
        row = _make_tensor(diagonal_shape[-1], mindspore.int32)
        padding = _make_tensor(0, _dtype(diagonal))
    return matrix_diag(diagonal, _k_0, row, row, padding)
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def _mat_mul(x, y):
    """Multiply `x` by `y`, using BatchMatMul when `x` has rank above 2."""
    if P.Rank()(x) <= 2:
        return math.MatMul()(x, y)
    return math.BatchMatMul()(x, y)
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
@bprop_getters.register(linalg.Svd)
def get_bprop_svd(self):
    """Generate bprop for Svd.

    Reads `full_matrices` and `compute_uv` from the primitive and returns a
    `bprop(a, out, dout)` closure that yields a one-element tuple `(da,)`.

    Inside `bprop`, a ValueError is raised when `full_matrices` is set and the
    last two dimensions of `a` differ by more than 1.
    """
    full_matrices = self.full_matrices
    compute_uv = self.compute_uv

    # Variant of the forward op that always produces u and v; needed when the
    # registered primitive was run with compute_uv disabled.
    svd = linalg.Svd(compute_uv=True)
    square = math.Square()
    matrix_set_diag = arrays.MatrixSetDiagV3()
    expand_dims = arrays.ExpandDims()

    def bprop(a, out, dout):
        if not compute_uv:
            # Only dout[0] is used on this path: da = u @ diag(dout[0]) @ adjoint(v).
            unused_s, u, v = svd(a)
            grad_s_mat = _matrix_diag(dout[0])
            da = _mat_mul(u, _mat_mul(grad_s_mat, _adjoint(v)))
            return (da,)

        a_shape = _shape(a)
        _check_dim(P.Rank()(a))
        small = a_shape[-2]
        large = a_shape[-1]
        s, u, v = out
        grad_s, grad_u, grad_v = dout

        # Swap roles so that small <= large from here on; the result is
        # transposed back at the end when the swap happened.
        use_adjoint = False
        if small > large:
            use_adjoint = True
            small, large = large, small
            u, v = v, u
            grad_u, grad_v = grad_v, grad_u

        if full_matrices and abs(small - large) > 1:
            raise ValueError("For 'Svd' gradient, not support for abs(m - n) > 1 with full_matrices is True.")

        s_mat = _matrix_diag(s)
        s_sq = square(s)

        # Damped reciprocal of the pairwise differences of squared singular
        # values, with the main diagonal overwritten by zeros.
        pairwise_gap = expand_dims(s_sq, -2) - expand_dims(s_sq, -1)
        f = matrix_set_diag(_safe_reciprocal(pairwise_gap), zeros_like(s), _k_0)
        s_inv_mat = _matrix_diag(_safe_reciprocal(s))

        # First `small` columns of v and of its incoming gradient.
        v1 = v[..., :, :small]
        grad_v1 = grad_v[..., :, :small]

        f_u = f * _mat_mul(_adjoint(u), grad_u)
        f_v = f * _mat_mul(_adjoint(v1), grad_v1)
        grad_s_mat = _matrix_diag(_cast(grad_s, _dtype(a)))
        sym_u = _mat_mul(f_u + _adjoint(f_u), s_mat)
        sym_v = _mat_mul(s_mat, f_v + _adjoint(f_v))
        term1_nouv = grad_s_mat + sym_u + sym_v

        term1 = _mat_mul(u, _mat_mul(term1_nouv, _adjoint(v1)))

        if small == large:
            da_before_transpose = term1
        else:
            grad_v1_t = _matrix_transpose(grad_v1)
            term2_nous = grad_v1_t - _mat_mul(_mat_mul(grad_v1_t, v1), _adjoint(v1))

            if full_matrices:
                # Remaining columns of v beyond the first `small`.
                v2 = v[..., :, small:large]
                grad_v2 = grad_v[..., :, small:large]
                term2_nous = term2_nous - _mat_mul(_mat_mul(_adjoint(v1), grad_v2), _adjoint(v2))

            term2 = _mat_mul(_mat_mul(u, s_inv_mat), term2_nous)
            da_before_transpose = term1 + term2

        if use_adjoint:
            da = _matrix_transpose(da_before_transpose)
        else:
            da = da_before_transpose

        return (da,)

    return bprop
|
|
@@ -1,72 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""Generate bprop for other ops"""
|
|
17
|
-
|
|
18
|
-
from mindspore.ops import operations as P
|
|
19
|
-
from mindspore.ops.operations import _grad_ops as G
|
|
20
|
-
from mindspore.ops.operations import _inner_ops as inner
|
|
21
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
22
|
-
from mindspore.ops._grad_experimental.grad_base import bprop_getters
|
|
23
|
-
|
|
24
|
-
# Unused parameters are placeholders.
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
@bprop_getters.register(P.InvertPermutation)
def get_bprop_invert_permutation(self):
    """Generate bprop for InvertPermutation: the input gets a zero gradient."""

    def bprop(x, out, dout):
        # `out` and `dout` are placeholders; they are not used.
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@bprop_getters.register(inner.SyncBatchNorm)
def get_bprop_sync_batch_norm(self):
    """Grad definition for `SyncBatchNorm` operation.

    Delegates to `SyncBatchNormGrad`, configured with the primitive's
    `epsilon`, `group` and `device_num`.
    """
    input_grad = G.SyncBatchNormGrad(self.epsilon, self.group, self.device_num)

    def bprop(x, scale, b, mean, variance, out, dout):
        # Items 3 and 4 of the forward output are fed to the grad op as the
        # saved mean and saved variance.
        grads = input_grad(dout[0], x, scale, out[3], out[4])
        # First three grad outputs go to x, scale and b; mean and variance get zeros.
        return (grads[0], grads[1], grads[2], zeros_like(mean), zeros_like(variance))

    return bprop
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
@bprop_getters.register(inner.GpuConvertToDynamicShape)
def get_bprop_gpu_convert_to_dynamic_shape(self):
    """Get backprop for GpuConvertToDynamicShape: `dout` is passed through unchanged."""

    def bprop(x, out, dout):
        grads = (dout,)
        return grads

    return bprop
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
@bprop_getters.register(P._DynamicLossScale)  # pylint: disable=W0212
def get_bprop_dynamic_loss_scale(self):
    """Get backprop for dynamic_loss_scale.

    The gradient of `x` is `dout * loss_scale`, computed by a Mul primitive
    tagged with the `split_overflow` and `layer_overflow` attributes;
    `loss_scale` itself gets a zero gradient.
    """
    scale_mul = P.Mul()
    scale_mul.add_prim_attr('split_overflow', True)
    scale_mul.add_prim_attr('layer_overflow', self.layer)

    def bprop(x, loss_scale, out, dout):
        dx = scale_mul(dout, loss_scale)
        d_loss_scale = zeros_like(loss_scale)
        return dx, d_loss_scale

    return bprop
|
|
@@ -1,112 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""Generate bprop for scalar ops"""
|
|
17
|
-
|
|
18
|
-
from mindspore.ops.operations import _scalar_ops
|
|
19
|
-
from mindspore.ops._grad_experimental.grad_base import bprop_getters
|
|
20
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
@bprop_getters.register(_scalar_ops.ScalarAdd)
def get_bprop_scalar_add(self):
    """Grad definition for `ScalarAdd` operation: both inputs receive `dout`."""

    def bprop(x, y, out, dout):
        grads = (dout, dout)
        return grads

    return bprop
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
@bprop_getters.register(_scalar_ops.ScalarSub)
def get_bprop_scalar_sub(self):
    """Grad definition for `ScalarSub` operation: `x` gets `dout`, `y` gets `0 - dout`."""

    def bprop(x, y, out, dout):
        dy = 0 - dout
        return dout, dy

    return bprop
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
@bprop_getters.register(_scalar_ops.ScalarMul)
def get_bprop_scalar_mul(self):
    """Grad definition for `ScalarMul` operation: each input's gradient is the other input times `dout`."""

    def bprop(x, y, out, dout):
        return y * dout, x * dout

    return bprop
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
@bprop_getters.register(_scalar_ops.ScalarDiv)
def get_bprop_scalar_div(self):
    """Grad definition for `ScalarDiv` operation.

    With out = x / y: dx = dout / y and dy = 0 - dx * out.
    """

    def bprop(x, y, out, dout):
        dx = dout / y
        # The gradient of y reuses dx and the forward result.
        dy = 0 - dx * out
        return dx, dy

    return bprop
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
@bprop_getters.register(_scalar_ops.ScalarFloordiv)
def get_bprop_scalar_floordiv(self):
    """Grad definition for `ScalarFloorDiv` operation: both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        dx = zeros_like(x)
        dy = zeros_like(y)
        return dx, dy

    return bprop
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
@bprop_getters.register(_scalar_ops.ScalarMod)
def get_bprop_scalar_mod(self):
    """Grad definition for `ScalarMod` operation.

    `x` receives `dout`; `y` receives `-dout * (x // y)`.
    """

    def bprop(x, y, out, dout):
        quotient = x // y
        return dout, -dout * quotient

    return bprop
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
@bprop_getters.register(_scalar_ops.scalar_eq)
@bprop_getters.register(_scalar_ops.scalar_le)
@bprop_getters.register(_scalar_ops.scalar_lt)
@bprop_getters.register(_scalar_ops.scalar_ge)
@bprop_getters.register(_scalar_ops.scalar_gt)
@bprop_getters.register(_scalar_ops.bit_and)
@bprop_getters.register(_scalar_ops.bit_or)
def get_bprop_scalar_logic(self):
    """Grad definition for `ScalarLogicOps` operation.

    Shared by the scalar comparison and bitwise primitives registered above;
    both inputs get zero gradients.
    """

    def bprop(x, y, out, dout):
        dx = zeros_like(x)
        dy = zeros_like(y)
        return dx, dy

    return bprop
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
@bprop_getters.register(_scalar_ops.ScalarBool)
def get_bprop_scalar_bool(self):
    """Grad definition for `ScalarBool` operation.

    The single input gets a zero gradient. The result is wrapped in a
    one-element tuple, matching every other single-input bprop in this
    package (one gradient entry per forward input).
    """

    def bprop(x, out, dout):
        # Return a tuple, not a bare value, so the gradient for `x` sits at
        # index 0 like in the other registered bprops.
        return (zeros_like(x),)

    return bprop
|
|
@@ -1,351 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""grad_sequence_ops"""
|
|
17
|
-
|
|
18
|
-
from mindspore.ops.operations import _sequence_ops as seq
|
|
19
|
-
from mindspore.ops import operations as P
|
|
20
|
-
from mindspore.ops import functional as F
|
|
21
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
22
|
-
from mindspore.ops._grad_experimental.grad_base import bprop_getters
|
|
23
|
-
from mindspore.ops.primitive import Primitive
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
tuple_setitem = Primitive("tuple_setitem")
|
|
27
|
-
list_setitem = Primitive("list_setitem")
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@bprop_getters.register(seq.SequenceCount)
def get_bprop_count(self):
    """Generate bprop for SequenceCount: both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        dx = zeros_like(x)
        dy = zeros_like(y)
        return (dx, dy)

    return bprop
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
@bprop_getters.register(seq.sequence_len)
def get_bprop_sequence_len(self):
    """Generate bprop for sequence_len: the input sequence gets a zero gradient."""

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
@bprop_getters.register(seq.SequenceAdd)
def get_bprop_sequence_add(self):
    """Generate bprop for SequenceAdd.

    Each input's gradient is a slice of `dout`, starting at the offset
    reported by `SequenceAddOffset` for that input.
    """

    def bprop(x, y, out, dout):
        offsets = seq.SequenceAddOffset()(x, y)
        x_len = len(x)
        total_len = x_len + len(y)
        # Slice arguments are (sequence, start, stop, step).
        dx = seq.SequenceSlice()(dout, offsets[0], x_len, 1)
        dy = seq.SequenceSlice()(dout, offsets[1], total_len, 1)
        return (dx, dy)

    return bprop
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
@bprop_getters.register(seq.SequenceUnstack)
def get_bprop_sequence_unstack(self):
    """Generate bprop for SequenceUnstack: `dout` is stacked back along the primitive's `axis`."""
    axis = self.axis

    def bprop(x, out, dout):
        stack = seq.SequenceStack(axis)
        dx = stack(dout)
        return (dx,)

    return bprop
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
@bprop_getters.register(seq.SequenceSlice)
def get_bprop_slice(self):
    """Generate bprop for SequenceSlice.

    The sequence gradient comes from `SequenceSliceGrad`; `start`, `stop`
    and `step` get zero gradients.
    """

    def bprop(x, start, stop, step, out, dout):
        dx = seq.SequenceSliceGrad()(dout, x, start, stop, step)
        return (dx, zeros_like(start), zeros_like(stop), zeros_like(step))

    return bprop
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
@bprop_getters.register(seq.SequenceIndex)
def get_bprop_index(self):
    """Generate bprop for SequenceIndex: all four inputs get zero gradients."""

    def bprop(x, y, start, end, out, dout):
        dx = zeros_like(x)
        dy = zeros_like(y)
        d_start = zeros_like(start)
        d_end = zeros_like(end)
        return (dx, dy, d_start, d_end)

    return bprop
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
@bprop_getters.register(seq.InSequence)
def get_bprop_insequence(self):
    """Generate bprop for InSequence.

    `x` gets `zeros_like(x)`; `y` gets the result of `SequenceZerosLike`.
    """

    def bprop(x, y, out, dout):
        dx = zeros_like(x)
        dy = seq.SequenceZerosLike()(y)
        return (dx, dy)

    return bprop
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
@bprop_getters.register("tuple_equal")
@bprop_getters.register("list_equal")
def get_bprop_seq_equal(self):
    """Generate bprop for tuple_equal and list_equal: both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        dx = zeros_like(x)
        dy = zeros_like(y)
        return (dx, dy)

    return bprop
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
@bprop_getters.register("shape_mul")
def get_bprop_shape_mul(self):
    """Generate bprop for shape_mul: the gradient is delegated to `ShapeMulGrad`."""

    def bprop(x, out, dout):
        grad_op = seq.ShapeMulGrad()
        return (grad_op(x, dout),)

    return bprop
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
@bprop_getters.register("tuple_setitem")
def get_bprop_tuple_setitem(self):
    """Generate bprop for tuple_setitem.

    The tuple gradient is `dout` with position `idx` replaced by zeros, the
    value gradient is `dout[idx]`, and the index gets a zero gradient.
    """

    def bprop(x, idx, value, out, dout):
        grad_seq = tuple_setitem(dout, idx, zeros_like(value))
        grad_value = dout[idx]
        # The index gradient is built from the constant 0, not from `idx`.
        grad_idx = zeros_like(0)
        return (grad_seq, grad_idx, grad_value)

    return bprop
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
@bprop_getters.register("list_setitem")
def get_bprop_list_setitem(self):
    """Generate bprop for list_setitem.

    The list gradient is `dout` with position `idx` replaced by zeros, the
    value gradient is `dout[idx]`, and the index gets a zero gradient.
    """

    def bprop(x, idx, value, out, dout):
        grad_seq = list_setitem(dout, idx, zeros_like(value))
        grad_value = dout[idx]
        # The index gradient is built from the constant 0, not from `idx`.
        grad_idx = zeros_like(0)
        return (grad_seq, grad_idx, grad_value)

    return bprop
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
@bprop_getters.register("ListInplaceReverse")
def get_bprop_list_inplace_reverse(self):
    """Generate bprop for list inplace reverse: the input gets a zero gradient."""

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
@bprop_getters.register("ListInplaceExtend")
def get_bprop_list_inplace_extend(self):
    """Generate bprop for list inplace extend: both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        dx = zeros_like(x)
        dy = zeros_like(y)
        return (dx, dy)

    return bprop
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
@bprop_getters.register("ListInplaceInsert")
def get_bprop_list_inplace_insert(self):
    """Generate bprop for list inplace insert: all three inputs get zero gradients."""

    def bprop(x, index, target, out, dout):
        dx = zeros_like(x)
        d_index = zeros_like(index)
        d_target = zeros_like(target)
        return (dx, d_index, d_target)

    return bprop
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
@bprop_getters.register("ListInplacePop")
def get_bprop_list_inplace_pop(self):
    """Generate bprop for list inplace pop: both inputs get zero gradients."""

    def bprop(x, index, out, dout):
        dx = zeros_like(x)
        d_index = zeros_like(index)
        return (dx, d_index)

    return bprop
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
@bprop_getters.register(seq.ListAppend)
def get_bprop_list_append(self):
    """Generate bprop for ListAppend.

    The list gradient comes from `ListAppendAndInsertGrad` called with index
    -1; the appended value gets a zero gradient.
    """

    def bprop(x, value, out, dout):
        grad_op = seq.ListAppendAndInsertGrad()
        return (grad_op(dout, -1), zeros_like(value))

    return bprop
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
@bprop_getters.register(seq.ListInsert)
def get_bprop_list_insert(self):
    """Generate bprop for ListInsert.

    The list gradient comes from `ListAppendAndInsertGrad` called with `idx`;
    the index and the inserted value get zero gradients.
    """

    def bprop(x, idx, value, out, dout):
        grad_op = seq.ListAppendAndInsertGrad()
        return (grad_op(dout, idx), zeros_like(idx), zeros_like(value))

    return bprop
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
@bprop_getters.register(seq.TupleToTensor)
def get_bprop_tuple_to_tensor(self):
    """Generate bprop for TupleToTensor.

    `dout` is cast to `F.typeof(x)` and converted with `TensorToTuple`;
    `dtype` gets a zero gradient.
    """

    def bprop(x, dtype, out, dout):
        casted = P.Cast()(dout, F.typeof(x))
        dx = seq.TensorToTuple()(casted)
        return (dx, zeros_like(dtype))

    return bprop
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
@bprop_getters.register(seq.ListToTensor)
def get_bprop_list_to_tensor(self):
    """Generate bprop for ListToTensor.

    `dout` is cast to `F.typeof(x)` and converted with `TensorToList`;
    `dtype` gets a zero gradient.
    """

    def bprop(x, dtype, out, dout):
        casted = P.Cast()(dout, F.typeof(x))
        dx = seq.TensorToList()(casted)
        return (dx, zeros_like(dtype))

    return bprop
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
@bprop_getters.register(P.ScalarToTensor)
def get_bprop_scalar_to_tensor(self):
    """Generate bprop for ScalarToTensor.

    `dout` is cast to `F.typeof(x)` and converted with `TensorToScalar`;
    `dtype` gets a zero gradient.
    """

    def bprop(x, dtype, out, dout):
        casted = P.Cast()(dout, F.typeof(x))
        dx = seq.TensorToScalar()(casted)
        return (dx, zeros_like(dtype))

    return bprop
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
@bprop_getters.register(seq.TensorToTuple)
def get_bprop_tensor_to_tuple(self):
    """Generate bprop for TensorToTuple: `dout` is converted back with `TupleToTensor` using `F.typeof(x)`."""

    def bprop(x, out, dout):
        to_tensor = seq.TupleToTensor()
        return (to_tensor(dout, F.typeof(x)),)

    return bprop
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
@bprop_getters.register(seq.TensorToList)
def get_bprop_tensor_to_list(self):
    """Generate bprop for TensorToList: `dout` is converted back with `ListToTensor` using `F.typeof(x)`."""

    def bprop(x, out, dout):
        to_tensor = seq.ListToTensor()
        return (to_tensor(dout, F.typeof(x)),)

    return bprop
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
@bprop_getters.register(seq.TensorToScalar)
def get_bprop_tensor_to_scalar(self):
    """Generate bprop for TensorToScalar: `dout` is converted back with `ScalarToTensor` using `F.typeof(x)`."""

    def bprop(x, out, dout):
        to_tensor = P.ScalarToTensor()
        return (to_tensor(dout, F.typeof(x)),)

    return bprop
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
@bprop_getters.register("tuple_le")
@bprop_getters.register("tuple_lt")
@bprop_getters.register("list_le")
@bprop_getters.register("list_lt")
def get_bprop_less(self):
    """Generate bprop for SequenceLessThan and SequenceLessEqual: both inputs get zero gradients."""

    def bprop(x, y, out, dout):
        dx = zeros_like(x)
        dy = zeros_like(y)
        return (dx, dy)

    return bprop
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
@bprop_getters.register(seq.SequenceMul)
def get_bprop_mul(self):
    """Generate bprop for SequenceMul.

    The sequence gradient starts as `x` and has each position `i` replaced
    by `dout[i]`; `y` gets a zero gradient.
    """

    def bprop(x, y, out, dout):
        # NOTE(review): only the first len(x) entries of dout are read here;
        # confirm whether the remaining entries are meant to contribute.
        grad_seq = x
        length = len(x)
        if isinstance(x, tuple):
            for pos in range(length):
                grad_seq = tuple_setitem(grad_seq, pos, dout[pos])
        else:
            for pos in range(length):
                grad_seq = list_setitem(grad_seq, pos, dout[pos])
        return (grad_seq, zeros_like(y))

    return bprop
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
@bprop_getters.register(seq.SequenceMin)
@bprop_getters.register(seq.SequenceMax)
def get_bprop_max_min(self):
    """Generate bprop for SequenceMin and SequenceMax.

    The gradient is a zero sequence with `dout` written at the position
    where `out` is found in `x`.
    """

    def bprop(x, out, dout):
        position = x.index(out)
        if isinstance(x, tuple):
            grad_seq = tuple_setitem(zeros_like(x), position, dout)
        else:
            grad_seq = list_setitem(zeros_like(x), position, dout)
        return (grad_seq,)

    return bprop
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
@bprop_getters.register("tuple_greater_than")
@bprop_getters.register("list_greater_than")
@bprop_getters.register("tuple_greater_equal")
@bprop_getters.register("list_greater_equal")
def get_bprop_greater(self):
    """Generate bprop for tuple_greater_than, list_greater_than,
    tuple_greater_equal, list_greater_equal.

    Both inputs get zero gradients.
    """

    def bprop(x, y, out, dout):
        dx = zeros_like(x)
        dy = zeros_like(y)
        return (dx, dy)

    return bprop
|