mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +26 -32
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +72 -95
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +173 -258
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +240 -145
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +13 -2
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +143 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +11 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +0 -14
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +316 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +21 -28
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +310 -207
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +82 -41
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +13 -18
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +22 -17
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +78 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +4 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +167 -189
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -8
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +470 -251
- mindspore/ops/function/random_func.py +86 -56
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +235 -19
- mindspore/ops/operations/__init__.py +25 -17
- mindspore/ops/operations/_grad_ops.py +52 -7
- mindspore/ops/operations/_inner_ops.py +213 -12
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +64 -280
- mindspore/ops/operations/comm_ops.py +105 -57
- mindspore/ops/operations/custom_ops.py +10 -3
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/math_ops.py +185 -138
- mindspore/ops/operations/nn_ops.py +716 -492
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +14 -12
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +6 -10
- mindspore/parallel/shard.py +4 -4
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
- mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
- mindspore/profiler/parser/ascend_op_generator.py +5 -5
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
- mindspore/profiler/parser/base_timeline_generator.py +9 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +37 -21
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +2 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +139 -71
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +525 -577
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +2 -2
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +14 -7
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +83 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +185 -45
- mindspore/train/serialization.py +390 -150
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +14 -10
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -17,20 +17,15 @@
|
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
from mindspore.ops._grad_experimental.grad_base import get_bprop_fn, get_taylor_fprop_fn
|
|
19
19
|
from mindspore.ops._grad_experimental import grad_array_ops
|
|
20
|
-
from mindspore.ops._grad_experimental import grad_image_ops
|
|
21
20
|
from mindspore.ops._grad_experimental import grad_inner_ops
|
|
22
21
|
from mindspore.ops._grad_experimental import grad_nn_ops
|
|
23
22
|
from mindspore.ops._grad_experimental import grad_math_ops
|
|
24
|
-
from mindspore.ops._grad_experimental import grad_linalg_ops
|
|
25
23
|
from mindspore.ops._grad_experimental import grad_sparse
|
|
26
24
|
from mindspore.ops._grad_experimental import grad_sparse_ops
|
|
27
|
-
from mindspore.ops._grad_experimental import grad_scalar_ops
|
|
28
25
|
from mindspore.ops._grad_experimental import grad_comm_ops
|
|
29
26
|
from mindspore.ops._grad_experimental import grad_debug_ops
|
|
30
27
|
from mindspore.ops._grad_experimental import grad_implementations
|
|
31
|
-
from mindspore.ops._grad_experimental import grad_other_ops
|
|
32
28
|
from mindspore.ops._grad_experimental import grad_quant_ops
|
|
33
|
-
from mindspore.ops._grad_experimental import grad_sequence_ops
|
|
34
29
|
from mindspore.ops._grad_experimental import taylor_rule
|
|
35
30
|
|
|
36
31
|
__all__ = ['get_bprop_fn', 'get_taylor_fprop_fn']
|
|
@@ -398,7 +398,6 @@ def get_bprop_extract_volume_patches(self):
|
|
|
398
398
|
expend_dims = P.ExpandDims()
|
|
399
399
|
scatter_nd = P.ScatterNd()
|
|
400
400
|
slice_op = P.Slice()
|
|
401
|
-
fill = P.Fill()
|
|
402
401
|
dtype = P.DType()
|
|
403
402
|
cast = P.Cast()
|
|
404
403
|
matmul = P.MatMul()
|
|
@@ -466,7 +465,7 @@ def get_bprop_extract_volume_patches(self):
|
|
|
466
465
|
idx_tensor = concat((expend_dims(x_idx_patched, -1), expend_dims(out_idx, -1)))
|
|
467
466
|
idx_map = P.Reshape()(idx_tensor, (-1, 2))
|
|
468
467
|
sp_shape = (x_indices_num, out_indices_num)
|
|
469
|
-
sp_mat_full = scatter_nd(idx_map, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
|
|
468
|
+
sp_mat_full = scatter_nd(idx_map, F.fill(dtype(dout), (out_indices_num,), 1), sp_shape)
|
|
470
469
|
sp_tensor = slice_op(sp_mat_full, (1, 0), (x_indices_num - 1, out_indices_num))
|
|
471
470
|
|
|
472
471
|
grad = P.Transpose()(dout, (0, 2, 3, 4, 1))
|
|
@@ -98,6 +98,7 @@ def get_bprop_send(self):
|
|
|
98
98
|
def bprop(x, out, dout):
|
|
99
99
|
dx = send_grad(virtual_input)
|
|
100
100
|
return (dx,)
|
|
101
|
+
|
|
101
102
|
return bprop
|
|
102
103
|
|
|
103
104
|
|
|
@@ -117,14 +118,17 @@ def get_bprop_receive(self):
|
|
|
117
118
|
else:
|
|
118
119
|
dx = depend(cast(out_tensor, F.dtype(x)), send_out)
|
|
119
120
|
return (dx,)
|
|
121
|
+
|
|
120
122
|
return bprop
|
|
121
123
|
|
|
122
124
|
|
|
123
125
|
@bprop_getters.register(_VirtualAdd)
|
|
124
126
|
def get_bprop_virtual_add(self):
|
|
125
127
|
"""Generate bprop for _VirtualAdd"""
|
|
128
|
+
|
|
126
129
|
def bprop(x, grad_accu, out, dout):
|
|
127
130
|
return (dout + grad_accu, zeros_like(grad_accu))
|
|
131
|
+
|
|
128
132
|
return bprop
|
|
129
133
|
|
|
130
134
|
|
|
@@ -181,7 +185,8 @@ def get_bprop_mirror_micro_step_operator(self):
|
|
|
181
185
|
scale = 1 / dev_num
|
|
182
186
|
|
|
183
187
|
all_reduce = AllReduce(group=group)
|
|
184
|
-
|
|
188
|
+
if "segment" in self.get_attr_dict():
|
|
189
|
+
all_reduce.add_prim_attr("segment", self.get_attr_dict()["segment"])
|
|
185
190
|
fusion = self.get_attr_dict()["fusion"]
|
|
186
191
|
all_reduce.add_prim_attr("fusion", fusion)
|
|
187
192
|
if hasattr(self, 'parameter'):
|
|
@@ -218,6 +223,7 @@ def get_bprop_mirror_micro_step_operator(self):
|
|
|
218
223
|
return (real_grad, cast(out_tensor, dtype(z)))
|
|
219
224
|
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
|
|
220
225
|
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out)
|
|
226
|
+
|
|
221
227
|
return bprop
|
|
222
228
|
|
|
223
229
|
|
|
@@ -227,6 +233,7 @@ def get_bprop_broad_cast(self):
|
|
|
227
233
|
|
|
228
234
|
def bprop(x, out, dout):
|
|
229
235
|
return (dout,)
|
|
236
|
+
|
|
230
237
|
return bprop
|
|
231
238
|
|
|
232
239
|
|
|
@@ -306,6 +313,8 @@ def get_bprop_micro_step_all_gather(self):
|
|
|
306
313
|
if do_mirror:
|
|
307
314
|
scale = 1.0 / self.rank_size
|
|
308
315
|
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
|
|
316
|
+
if "segment" in self.get_attr_dict():
|
|
317
|
+
all_reduce.add_prim_attr("segment", self.get_attr_dict()["segment"])
|
|
309
318
|
rank = get_rank(self.group)
|
|
310
319
|
dev_num = get_group_size(self.group)
|
|
311
320
|
split = P.Split(output_num=dev_num)
|
|
@@ -502,6 +511,7 @@ def get_bprop_mirror_operator(self):
|
|
|
502
511
|
dx = RowTensorInner(indices, grad, dout.dense_shape)
|
|
503
512
|
|
|
504
513
|
return (dx,)
|
|
514
|
+
|
|
505
515
|
return bprop
|
|
506
516
|
|
|
507
517
|
|
|
@@ -555,6 +565,7 @@ def get_bprop_mirror_mini_step_operator(self):
|
|
|
555
565
|
dx = zeros_like(x) # The grad accumulation do not support row tensor now
|
|
556
566
|
|
|
557
567
|
return (dx, zeros_like(z))
|
|
568
|
+
|
|
558
569
|
return bprop
|
|
559
570
|
|
|
560
571
|
|
|
@@ -569,7 +580,7 @@ def get_bprop_virtual_div_operator(self):
|
|
|
569
580
|
def bprop(x, out, dout):
|
|
570
581
|
if issubclass_(F.typeof(dout), mstype.tensor_type):
|
|
571
582
|
if issubclass_(F.dtype(dout), mstype.bool_) or issubclass_(F.dtype(dout), mstype.int32) \
|
|
572
|
-
|
|
583
|
+
or issubclass_(F.dtype(dout), mstype.int16):
|
|
573
584
|
return (dout,)
|
|
574
585
|
dx = op(dout, cast(F.scalar_to_tensor(divisor), dtype(dout)))
|
|
575
586
|
return (dx,)
|
|
@@ -588,6 +599,7 @@ def get_bprop_virtual_div_operator(self):
|
|
|
588
599
|
ele_grad = op(dout[i], cast(F.scalar_to_tensor(divisor), dtype(dout[i])))
|
|
589
600
|
dx.append(ele_grad)
|
|
590
601
|
return (dx,)
|
|
602
|
+
|
|
591
603
|
return bprop
|
|
592
604
|
|
|
593
605
|
|
|
@@ -597,4 +609,5 @@ def get_bprop_get_tensor_slice_operator(self):
|
|
|
597
609
|
|
|
598
610
|
def bprop(x, dev_mat, tensor_map, out, dout):
|
|
599
611
|
return (zeros_like(x),)
|
|
612
|
+
|
|
600
613
|
return bprop
|
|
@@ -16,48 +16,11 @@
|
|
|
16
16
|
"""Generate bprop for debug ops"""
|
|
17
17
|
|
|
18
18
|
from mindspore.ops import operations as P
|
|
19
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
20
19
|
from mindspore.ops._grad_experimental.grad_base import bprop_getters
|
|
21
20
|
|
|
22
21
|
# Unused parameters are placeholders.
|
|
23
22
|
|
|
24
23
|
|
|
25
|
-
@bprop_getters.register(P.ScalarSummary)
|
|
26
|
-
def get_bprop_scalar_summary(self):
|
|
27
|
-
"""Generate bprop for ScalarSummary"""
|
|
28
|
-
|
|
29
|
-
def bprop(tag, x, out, dout):
|
|
30
|
-
return tag, zeros_like(x)
|
|
31
|
-
return bprop
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
@bprop_getters.register(P.TensorSummary)
|
|
35
|
-
def get_bprop_tensor_summary(self):
|
|
36
|
-
"""Generate bprop for TensorSummary"""
|
|
37
|
-
|
|
38
|
-
def bprop(tag, x, out, dout):
|
|
39
|
-
return tag, zeros_like(x)
|
|
40
|
-
return bprop
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
@bprop_getters.register(P.ImageSummary)
|
|
44
|
-
def get_bprop_image_summary(self):
|
|
45
|
-
"""Generate bprop for ImageSummary"""
|
|
46
|
-
|
|
47
|
-
def bprop(tag, x, out, dout):
|
|
48
|
-
return tag, zeros_like(x)
|
|
49
|
-
return bprop
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@bprop_getters.register(P.HistogramSummary)
|
|
53
|
-
def get_bprop_histogram_summary(self):
|
|
54
|
-
"""Generate bprop for HistogramSummary"""
|
|
55
|
-
|
|
56
|
-
def bprop(tag, x, out, dout):
|
|
57
|
-
return tag, zeros_like(x)
|
|
58
|
-
return bprop
|
|
59
|
-
|
|
60
|
-
|
|
61
24
|
@bprop_getters.register(P.InsertGradientOf)
|
|
62
25
|
def get_bprop_insert_gradient_of(self):
|
|
63
26
|
"""Generate bprop for InsertGradientOf"""
|
|
@@ -191,3 +191,13 @@ def bprop_scalar_not(x, out, dout):
|
|
|
191
191
|
def bprop_tensor_move(x, out, dout):
|
|
192
192
|
"""Backpropagator for primitive `TensorMove`."""
|
|
193
193
|
return (dout,)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@bprops.register("DictInplaceSetItem")
|
|
197
|
+
def get_bprop_dict_inplace_setitem(self):
|
|
198
|
+
"""Generate bprop for dict inplace pop"""
|
|
199
|
+
|
|
200
|
+
def bprop(x, key, target, out, dout):
|
|
201
|
+
return (zeros_like(x), zeros_like(key), zeros_like(target))
|
|
202
|
+
|
|
203
|
+
return bprop
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2021-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -16,57 +16,11 @@
|
|
|
16
16
|
"""inner_ops"""
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
|
|
19
|
-
from mindspore.ops.operations.comm_ops import _VirtualPipelineEnd
|
|
20
19
|
from mindspore.ops.operations import _inner_ops as inner
|
|
21
20
|
from mindspore.ops.operations import _grad_ops as G
|
|
22
|
-
from mindspore.ops import functional as F
|
|
23
21
|
from mindspore.ops import operations as P
|
|
24
22
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
25
|
-
from mindspore.ops.
|
|
26
|
-
from mindspore.ops._grad_experimental.grad_base import bprop_getters, sum_grad_reduce_axis
|
|
27
|
-
import mindspore as ms
|
|
28
|
-
|
|
29
|
-
reshape = P.Reshape()
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
@bprop_getters.register(inner.TensorCopySlices)
|
|
33
|
-
def get_bprop_tensor_copy_slices(self):
|
|
34
|
-
"""Generate bprop for TensorCopySlices"""
|
|
35
|
-
tensor_copy_slices = inner.TensorCopySlices()
|
|
36
|
-
|
|
37
|
-
def bprop(x, update, begin, end, stride, out, dout):
|
|
38
|
-
x_grad = tensor_copy_slices(dout, zeros_like(update), begin, end, stride)
|
|
39
|
-
update_grad = F.strided_slice(dout, begin, end, stride)
|
|
40
|
-
res = (x_grad, update_grad, zeros_like(begin), zeros_like(end), zeros_like(stride))
|
|
41
|
-
return res
|
|
42
|
-
|
|
43
|
-
return bprop
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
@bprop_getters.register(_VirtualPipelineEnd)
|
|
47
|
-
def get_bprop_virtual_pipeline_end(self):
|
|
48
|
-
"""Backpropagator for _VirtualPipelineEnd."""
|
|
49
|
-
grad = _VirtualPipelineEnd()
|
|
50
|
-
|
|
51
|
-
def bprop(x, out, dout):
|
|
52
|
-
dx = grad(dout)
|
|
53
|
-
return (dx,)
|
|
54
|
-
return bprop
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
@bprop_getters.register(inner.DynamicResizeNearestNeighbor)
|
|
58
|
-
def get_bprop_dynamic_resize_nearest_neighbor(self):
|
|
59
|
-
"""Generate bprop for DynamicResizeNearestNeighbor"""
|
|
60
|
-
op = G.ResizeNearestNeighborGrad(self.align_corners)
|
|
61
|
-
shape_op = P.Shape()
|
|
62
|
-
|
|
63
|
-
def bprop(inputs, size, out, dout):
|
|
64
|
-
shp = shape_op(inputs)
|
|
65
|
-
# 2 and 3 represent the height and width
|
|
66
|
-
shp = (shp[2:])
|
|
67
|
-
return (op(dout, shp), zeros_like(size))
|
|
68
|
-
|
|
69
|
-
return bprop
|
|
23
|
+
from mindspore.ops._grad_experimental.grad_base import bprop_getters
|
|
70
24
|
|
|
71
25
|
|
|
72
26
|
@bprop_getters.register(inner.ParallelResizeBilinear)
|
|
@@ -82,115 +36,6 @@ def get_bprop_parallel_resize_bilinear(self):
|
|
|
82
36
|
return bprop
|
|
83
37
|
|
|
84
38
|
|
|
85
|
-
@bprop_getters.register(inner.ConvertToDynamic)
|
|
86
|
-
def get_bprop_gpu_convert_to_dynamic_rank(self):
|
|
87
|
-
"""Get backprop for ConvertToDynamic."""
|
|
88
|
-
|
|
89
|
-
def bprop(x, out, dout):
|
|
90
|
-
return (dout,)
|
|
91
|
-
return bprop
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
def _get_matrix_diag_assist(x_shape, x_dtype):
|
|
95
|
-
base_eye = P.Eye()(x_shape[-1], x_shape[-1], x_dtype).flatten()
|
|
96
|
-
tile = P.Tile()(base_eye, x_shape[:-1])
|
|
97
|
-
assist = P.Reshape()(tile, x_shape + (x_shape[-1],))
|
|
98
|
-
return assist
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
def _get_matrix_diag_part_assist(x_shape, x_dtype):
|
|
102
|
-
base_eye = P.Eye()(x_shape[-2], x_shape[-1], x_dtype).flatten()
|
|
103
|
-
tile = P.Tile()(base_eye, x_shape[:-2])
|
|
104
|
-
assist = P.Reshape()(tile, x_shape)
|
|
105
|
-
return assist
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
@_primexpr
|
|
109
|
-
def _get_min(x):
|
|
110
|
-
return min(x)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
@bprop_getters.register(inner.MatrixDiag)
|
|
114
|
-
def get_bprop_matrix_diag(self):
|
|
115
|
-
"""Generate bprop for MatrixDiag"""
|
|
116
|
-
get_dtype = P.DType()
|
|
117
|
-
|
|
118
|
-
def bprop(x, y, out, dout):
|
|
119
|
-
shape = F.shape(dout)
|
|
120
|
-
dtype = get_dtype(dout)
|
|
121
|
-
assist = _get_matrix_diag_part_assist(shape, dtype)
|
|
122
|
-
dx = inner.MatrixDiagPart()(dout, assist)
|
|
123
|
-
return dx, zeros_like(y)
|
|
124
|
-
|
|
125
|
-
return bprop
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
@bprop_getters.register(inner.MatrixDiagPart)
|
|
129
|
-
def get_bprop_matrix_diag_part(self):
|
|
130
|
-
"""Generate bprop for MatrixDiagPart"""
|
|
131
|
-
get_dtype = P.DType()
|
|
132
|
-
|
|
133
|
-
def bprop(x, y, out, dout):
|
|
134
|
-
x_shape = F.shape(x)[-2:]
|
|
135
|
-
if x_shape[0] == x_shape[1]:
|
|
136
|
-
shape = F.shape(dout)
|
|
137
|
-
dtype = get_dtype(dout)
|
|
138
|
-
assist = _get_matrix_diag_assist(shape, dtype)
|
|
139
|
-
return inner.MatrixDiag()(dout, assist), zeros_like(y)
|
|
140
|
-
shape = F.shape(x)
|
|
141
|
-
dtype = get_dtype(x)
|
|
142
|
-
assist = _get_matrix_diag_part_assist(shape, dtype)
|
|
143
|
-
return inner.MatrixSetDiag()(zeros_like(x), dout, assist), zeros_like(y)
|
|
144
|
-
|
|
145
|
-
return bprop
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
@bprop_getters.register(inner.MatrixSetDiag)
|
|
149
|
-
def get_bprop_matrix_set_diag(self):
|
|
150
|
-
"""Generate bprop for MatrixSetDiag"""
|
|
151
|
-
get_dtype = P.DType()
|
|
152
|
-
|
|
153
|
-
def bprop(x, y, z, out, dout):
|
|
154
|
-
input_shape = F.shape(x)
|
|
155
|
-
batch_shape = input_shape[:-2]
|
|
156
|
-
matrix_shape = input_shape[-2:]
|
|
157
|
-
diag_shape = batch_shape + (_get_min(matrix_shape),)
|
|
158
|
-
|
|
159
|
-
grad_shape = F.shape(dout)
|
|
160
|
-
grad_dtype = get_dtype(dout)
|
|
161
|
-
assist = _get_matrix_diag_part_assist(grad_shape, grad_dtype)
|
|
162
|
-
dx = inner.MatrixSetDiag()(dout, P.Zeros()(diag_shape, grad_dtype), assist)
|
|
163
|
-
dy = inner.MatrixDiagPart()(dout, assist)
|
|
164
|
-
dz = zeros_like(z)
|
|
165
|
-
return dx, dy, dz
|
|
166
|
-
|
|
167
|
-
return bprop
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
@bprop_getters.register(inner.DSDMatmul)
|
|
171
|
-
def get_dsd_matmul_bprop(self):
|
|
172
|
-
def bprop(w1_gm, w2_gm, v_gm, out, dout):
|
|
173
|
-
d_w1_gm, d_w2_gm, d_v_gm = inner.DSDGrad()(w1_gm, w2_gm, v_gm, out, dout)
|
|
174
|
-
return d_w1_gm, d_w2_gm, d_v_gm
|
|
175
|
-
|
|
176
|
-
return bprop
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
@bprop_getters.register(inner.MatmulDDS)
|
|
180
|
-
def get_bprop(self):
|
|
181
|
-
"""brop of the matmulDDS operator"""
|
|
182
|
-
|
|
183
|
-
def bprop(q, k, local_mask, global_mask, out, d_out):
|
|
184
|
-
lc, gc = out
|
|
185
|
-
d_lc, d_gc = d_out
|
|
186
|
-
dq, dk = inner.MatmulDDSGrad()(q, k, lc, gc, d_lc, d_gc)
|
|
187
|
-
dk = P.Transpose()(dk, (1, 0, 3, 2))
|
|
188
|
-
all_d = (dq, dk, zeros_like(local_mask), zeros_like(global_mask))
|
|
189
|
-
return all_d
|
|
190
|
-
|
|
191
|
-
return bprop
|
|
192
|
-
|
|
193
|
-
|
|
194
39
|
@bprop_getters.register(inner.PsROIPooling)
|
|
195
40
|
def get_bprop_ps_roi_pooling(self):
|
|
196
41
|
"""Grad definition for `PsROIPooling` operation."""
|
|
@@ -224,62 +69,3 @@ def get_bprop_ps_roi_pooling(self):
|
|
|
224
69
|
return dx, zeros_like(rois)
|
|
225
70
|
|
|
226
71
|
return bprop
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
@bprop_getters.register(inner.DynamicBroadcastTo)
|
|
230
|
-
def get_bprop_dynamic_broadcast_to(self):
|
|
231
|
-
"""Generate bprop for DynamicBroadcastTo"""
|
|
232
|
-
shape_op = P.Shape()
|
|
233
|
-
|
|
234
|
-
def bprop(x, shp, out, dout):
|
|
235
|
-
x_shape = shape_op(x)
|
|
236
|
-
broadcast_shape = shape_op(out)
|
|
237
|
-
|
|
238
|
-
_, reduction_axes = inner.DynamicBroadcastGradientArgs()(broadcast_shape, x_shape)
|
|
239
|
-
out_type = dout.dtype
|
|
240
|
-
if out_type in (ms.int16, ms.int32, ms.int64):
|
|
241
|
-
dout = P.Cast()(dout, ms.float32)
|
|
242
|
-
reduced_grad = sum_grad_reduce_axis(dout, reduction_axes, keep_dims=True)
|
|
243
|
-
reduced_grad = P.Cast()(reduced_grad, out_type)
|
|
244
|
-
else:
|
|
245
|
-
reduced_grad = sum_grad_reduce_axis(dout, reduction_axes, keep_dims=True)
|
|
246
|
-
dx = reshape(reduced_grad, x_shape)
|
|
247
|
-
return dx, zeros_like(shp)
|
|
248
|
-
|
|
249
|
-
return bprop
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
@bprop_getters.register(inner.ConvertToAdapterTensor)
|
|
253
|
-
def get_bprop_convert_to_adapter_tensor(self):
|
|
254
|
-
"""Generate bprop for ConvertToAdapterTensor"""
|
|
255
|
-
|
|
256
|
-
def bprop(x, out, dout):
|
|
257
|
-
return (dout,)
|
|
258
|
-
|
|
259
|
-
return bprop
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
@bprop_getters.register(inner.ConvertToMsTensor)
|
|
263
|
-
def get_bprop_convert_to_ms_tensor(self):
|
|
264
|
-
"""Generate bprop for ConvertToMsTensor"""
|
|
265
|
-
|
|
266
|
-
def bprop(x, out, dout):
|
|
267
|
-
return (dout,)
|
|
268
|
-
|
|
269
|
-
return bprop
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
@bprop_getters.register(inner.SiLU)
|
|
273
|
-
def get_bprop_silu(self):
|
|
274
|
-
"""Generate bprop for SiLU"""
|
|
275
|
-
sigmoid_grad = G.SigmoidGrad()
|
|
276
|
-
mul_func = P.Mul()
|
|
277
|
-
|
|
278
|
-
def bprop(x, out, dout):
|
|
279
|
-
sigmoid_input = P.Sigmoid()(x)
|
|
280
|
-
bc_dx = mul_func(x, dout)
|
|
281
|
-
bc_dy = mul_func(sigmoid_input, dout)
|
|
282
|
-
dx = sigmoid_grad(sigmoid_input, bc_dx)
|
|
283
|
-
return (dx+bc_dy,)
|
|
284
|
-
|
|
285
|
-
return bprop
|
|
@@ -34,14 +34,12 @@ from mindspore.ops.operations.math_ops import MatrixTriangularSolve
|
|
|
34
34
|
from mindspore.ops.operations.math_ops import NanToNum
|
|
35
35
|
from mindspore.ops.operations.math_ops import FFTWithSize
|
|
36
36
|
from mindspore.ops.operations.math_ops import Cholesky
|
|
37
|
-
from mindspore.ops.operations.math_ops import Fmin
|
|
38
37
|
from mindspore.ops.operations.math_ops import CholeskySolve
|
|
39
38
|
from mindspore.ops.operations.math_ops import InplaceIndexAdd
|
|
40
39
|
from mindspore.ops.operations.math_ops import TridiagonalSolve
|
|
41
40
|
from mindspore.ops.operations.math_ops import Diagonal
|
|
42
41
|
from mindspore.ops.operations.math_ops import EuclideanNorm
|
|
43
42
|
from mindspore.ops.operations.array_ops import Transpose, MatrixSetDiagV3
|
|
44
|
-
from mindspore.ops.operations.math_ops import Fmax
|
|
45
43
|
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
|
|
46
44
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
47
45
|
from mindspore.ops.primitive import _primexpr
|
|
@@ -424,115 +422,6 @@ def get_bprop_polar(self):
|
|
|
424
422
|
return bprop
|
|
425
423
|
|
|
426
424
|
|
|
427
|
-
@bprop_getters.register(Fmin)
|
|
428
|
-
def get_bprop_fmin(self):
|
|
429
|
-
"""Grad definition for 'Fmin' operation"""
|
|
430
|
-
shape_ = P.Shape()
|
|
431
|
-
masked_fill_op = P.MaskedFill()
|
|
432
|
-
logical_or_op = P.LogicalOr()
|
|
433
|
-
logical_not_op = P.LogicalNot()
|
|
434
|
-
logical_and_op = P.LogicalAnd()
|
|
435
|
-
mul_op = P.Mul()
|
|
436
|
-
is_nan_op = P.IsNan()
|
|
437
|
-
reshape_ = P.Reshape()
|
|
438
|
-
|
|
439
|
-
def bprop(x1, x2, out, dout):
|
|
440
|
-
x1_dtype = F.dtype(x1)
|
|
441
|
-
x2_dtype = F.dtype(x2)
|
|
442
|
-
x1 = F.cast(x1, mstype.float32)
|
|
443
|
-
x2 = F.cast(x2, mstype.float32)
|
|
444
|
-
dout = F.cast(dout, mstype.float32)
|
|
445
|
-
b1 = logical_or_op((x1 <= x2), is_nan_op(x2))
|
|
446
|
-
b2 = logical_or_op((x2 < x1), logical_and_op(is_nan_op(x1), logical_not_op(is_nan_op(x2))))
|
|
447
|
-
rx1 = masked_fill_op(x1, b1, 1.)
|
|
448
|
-
rx1 = masked_fill_op(rx1, logical_not_op(b1), 0.)
|
|
449
|
-
rx2 = masked_fill_op(x2, b2, 1.)
|
|
450
|
-
rx2 = masked_fill_op(rx2, logical_not_op(b2), 0.)
|
|
451
|
-
rrx1 = mul_op(rx1, dout)
|
|
452
|
-
rrx2 = mul_op(rx2, dout)
|
|
453
|
-
shape_of_x1 = shape_(x1)
|
|
454
|
-
shape_of_x2 = shape_(x2)
|
|
455
|
-
x1_dim = len(shape_of_x1)
|
|
456
|
-
x2_dim = len(shape_of_x2)
|
|
457
|
-
if x1_dim == 0 and x2_dim != 0:
|
|
458
|
-
sum_r1 = rrx1.sum()
|
|
459
|
-
sum_r2 = rrx2
|
|
460
|
-
elif x1_dim == 0 and x2_dim == 0:
|
|
461
|
-
sum_r1 = rrx1.sum()
|
|
462
|
-
sum_r2 = rrx2.sum()
|
|
463
|
-
elif x1_dim != 0 and x2_dim == 0:
|
|
464
|
-
sum_r2 = rrx2.sum()
|
|
465
|
-
sum_r1 = rrx1
|
|
466
|
-
else:
|
|
467
|
-
rx, ry = DynamicBroadcastGradientArgs()(shape_of_x1, shape_of_x2)
|
|
468
|
-
sum_r1 = sum_grad_reduce_axis(rrx1, rx)
|
|
469
|
-
sum_r2 = sum_grad_reduce_axis(rrx2, ry)
|
|
470
|
-
brrx1 = reshape_(sum_r1, shape_of_x1)
|
|
471
|
-
brrx2 = reshape_(sum_r2, shape_of_x2)
|
|
472
|
-
brrx1 = F.cast(brrx1, x1_dtype)
|
|
473
|
-
brrx2 = F.cast(brrx2, x2_dtype)
|
|
474
|
-
return brrx1, brrx2
|
|
475
|
-
|
|
476
|
-
return bprop
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
@bprop_getters.register(Fmax)
|
|
480
|
-
def get_bprop_fmax(self):
|
|
481
|
-
"""Grad definition for 'Fmax' operation"""
|
|
482
|
-
shape_ = P.Shape()
|
|
483
|
-
masked_fill_op = P.MaskedFill()
|
|
484
|
-
logical_or_op = P.LogicalOr()
|
|
485
|
-
logical_not_op = P.LogicalNot()
|
|
486
|
-
logical_and_op = P.LogicalAnd()
|
|
487
|
-
mul_op = P.Mul()
|
|
488
|
-
is_nan_op = P.IsNan()
|
|
489
|
-
reshape_ = P.Reshape()
|
|
490
|
-
|
|
491
|
-
def bprop(x1, x2, out, dout):
|
|
492
|
-
x1_dtype = F.dtype(x1)
|
|
493
|
-
x2_dtype = F.dtype(x2)
|
|
494
|
-
if x1_dtype != mstype.float32:
|
|
495
|
-
x1 = F.cast(x1, mstype.float32)
|
|
496
|
-
dout = F.cast(dout, mstype.float32)
|
|
497
|
-
if x2_dtype != mstype.float32:
|
|
498
|
-
x2 = F.cast(x2, mstype.float32)
|
|
499
|
-
dout = F.cast(dout, mstype.float32)
|
|
500
|
-
b1 = logical_or_op(logical_and_op((x1 >= x2), logical_not_op(is_nan_op(x1))), is_nan_op(x2))
|
|
501
|
-
b2 = logical_or_op(logical_and_op(x2 > x1, logical_not_op(is_nan_op(x2))),
|
|
502
|
-
logical_and_op(is_nan_op(x1), logical_not_op(is_nan_op(x2))))
|
|
503
|
-
rx1 = masked_fill_op(x1, b1, 1.)
|
|
504
|
-
rx1 = masked_fill_op(rx1, logical_not_op(b1), 0.)
|
|
505
|
-
rx2 = masked_fill_op(x2, b2, 1.)
|
|
506
|
-
rx2 = masked_fill_op(rx2, logical_not_op(b2), 0.)
|
|
507
|
-
rrx1 = mul_op(rx1, dout)
|
|
508
|
-
rrx2 = mul_op(rx2, dout)
|
|
509
|
-
shape_of_x1 = shape_(x1)
|
|
510
|
-
shape_of_x2 = shape_(x2)
|
|
511
|
-
x1_dim = len(shape_of_x1)
|
|
512
|
-
x2_dim = len(shape_of_x2)
|
|
513
|
-
if x1_dim == 0 and x2_dim != 0:
|
|
514
|
-
sum_r1 = rrx1.sum()
|
|
515
|
-
sum_r2 = rrx2
|
|
516
|
-
elif x1_dim == 0 and x2_dim == 0:
|
|
517
|
-
sum_r1 = rrx1.sum()
|
|
518
|
-
sum_r2 = rrx2.sum()
|
|
519
|
-
elif x1_dim != 0 and x2_dim == 0:
|
|
520
|
-
sum_r2 = rrx2.sum()
|
|
521
|
-
sum_r1 = rrx1
|
|
522
|
-
else:
|
|
523
|
-
rx, ry = DynamicBroadcastGradientArgs()(shape_of_x1, shape_of_x2)
|
|
524
|
-
sum_r1 = sum_grad_reduce_axis(rrx1, rx)
|
|
525
|
-
sum_r2 = sum_grad_reduce_axis(rrx2, ry)
|
|
526
|
-
brrx1 = reshape_(sum_r1, shape_of_x1)
|
|
527
|
-
brrx2 = reshape_(sum_r2, shape_of_x2)
|
|
528
|
-
brrx1 = F.cast(brrx1, x1_dtype)
|
|
529
|
-
brrx2 = F.cast(brrx2, x2_dtype)
|
|
530
|
-
return brrx1, brrx2
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
return bprop
|
|
534
|
-
|
|
535
|
-
|
|
536
425
|
@bprop_getters.register(TridiagonalSolve)
|
|
537
426
|
def get_bprop_tridiagonalsolve(self):
|
|
538
427
|
"""Grad definition for 'TridiagonalSolve' operation"""
|
|
@@ -1127,73 +1016,3 @@ def get_bprop_tensor_add(self):
|
|
|
1127
1016
|
return binop_grad_common(x, y, dout, dout)
|
|
1128
1017
|
|
|
1129
1018
|
return bprop
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
@bprop_getters.register(P.BitwiseAnd)
|
|
1133
|
-
def get_bprop_bitwiseand(self):
|
|
1134
|
-
"""Grad definition for `BitwiseAnd` operation."""
|
|
1135
|
-
|
|
1136
|
-
def bprop(x, y, out, dout):
|
|
1137
|
-
return zeros_like(x), zeros_like(y)
|
|
1138
|
-
|
|
1139
|
-
return bprop
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
@bprop_getters.register(P.BitwiseOr)
|
|
1143
|
-
def get_bprop_bitwiseor(self):
|
|
1144
|
-
"""Grad definition for `BitwiseOr` operation."""
|
|
1145
|
-
|
|
1146
|
-
def bprop(x, y, out, dout):
|
|
1147
|
-
return zeros_like(x), zeros_like(y)
|
|
1148
|
-
|
|
1149
|
-
return bprop
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
@bprop_getters.register(P.BitwiseXor)
|
|
1153
|
-
def get_bprop_bitwisexor(self):
|
|
1154
|
-
"""Grad definition for `BitwiseXor` operation."""
|
|
1155
|
-
|
|
1156
|
-
def bprop(x, y, out, dout):
|
|
1157
|
-
return zeros_like(x), zeros_like(y)
|
|
1158
|
-
|
|
1159
|
-
return bprop
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
@bprop_getters.register(P.InplaceUpdate)
|
|
1163
|
-
def get_bprop_inplace_update(self):
|
|
1164
|
-
"""Grad definition for `InplaceUpdate` operation."""
|
|
1165
|
-
|
|
1166
|
-
def bprop(x, v, out, dout):
|
|
1167
|
-
return zeros_like(x), zeros_like(v)
|
|
1168
|
-
|
|
1169
|
-
return bprop
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
@bprop_getters.register(P.InplaceUpdateV2)
|
|
1173
|
-
def get_bprop_inplace_update_v2(self):
|
|
1174
|
-
"""Grad definition for `InplaceUpdateV2` operation."""
|
|
1175
|
-
|
|
1176
|
-
def bprop(x, indices, v, out, dout):
|
|
1177
|
-
return zeros_like(x), zeros_like(indices), zeros_like(v)
|
|
1178
|
-
|
|
1179
|
-
return bprop
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
@bprop_getters.register(P.InplaceSub)
|
|
1183
|
-
def get_bprop_inplace_sub(self):
|
|
1184
|
-
"""Grad definition for `InplaceSub` operation."""
|
|
1185
|
-
|
|
1186
|
-
def bprop(x, input_v, out, dout):
|
|
1187
|
-
return zeros_like(x), zeros_like(input_v)
|
|
1188
|
-
|
|
1189
|
-
return bprop
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
@bprop_getters.register(P.InplaceAdd)
|
|
1193
|
-
def get_bprop_inplace_add(self):
|
|
1194
|
-
"""Grad definition for `InplaceAdd` operation."""
|
|
1195
|
-
|
|
1196
|
-
def bprop(x, input_v, out, dout):
|
|
1197
|
-
return zeros_like(x), zeros_like(input_v)
|
|
1198
|
-
|
|
1199
|
-
return bprop
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
1
16
|
"""Define the grad rules of math related operations."""
|
|
2
17
|
|
|
3
18
|
from mindspore.ops import functional as F
|
|
@@ -299,7 +299,7 @@ def dsdbpropimpl(w1_gm, w2_gm, v_gm, a_gm, d_a_gm, d_w1_gm={}, d_w2_gm={}, d_v_g
|
|
|
299
299
|
global_size // 16, 16, 16),
|
|
300
300
|
name='v_global_l0b', scope=tik.scope_cb)
|
|
301
301
|
|
|
302
|
-
# d_w_global
|
|
302
|
+
# d_w_global, 小z大n
|
|
303
303
|
d_w_global_l0c = tik_inst.Tensor('float32', (global_size // 16, head_size // (16 * ub_time), 16, 16),
|
|
304
304
|
name='d_w_global_l0c', scope=tik.scope_cc)
|
|
305
305
|
d_w_global_ub = tik_inst.Tensor('float16', (global_size // 16,
|