mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -1,269 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""GraphKernel expander utils"""
|
|
16
|
-
from abc import ABCMeta, abstractmethod
|
|
17
|
-
from mindspore._extends.graph_kernel.model import model_builder as builder
|
|
18
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class Expander(metaclass=ABCMeta):
    """
    Expander is the base class of expanders.

    An expander turns one operator, described by an ``expand_info`` dict, into
    a graph built with ``model_builder.GraphBuilder``. Subclasses must override
    `_expand`; decorators from `ExpanderInfoValidator` chain extra checks onto
    `_check`, which `run` calls before expanding.
    """

    def __init__(self, expand_info):
        """
        Args:
            expand_info (dict): operator description; the keys read here are
                "name", "input_desc", "output_desc", "attr" and "process".
        """
        self.name = expand_info["name"]
        # Input descriptors (dicts with 'shape', 'data_type', 'format');
        # `run` replaces this list with builder tensors before `_expand`.
        self.inputs = expand_info["input_desc"]
        # Expected output descriptors, compared in `_check_output_same`.
        self.outputs = expand_info["output_desc"]
        self.attrs = expand_info["attr"]
        self.processor = expand_info["process"]

    def run(self):
        """
        Expand the operator to a graph.

        Returns:
            The first graph returned by the builder, with `self.processor` set.

        Raises:
            GraphKernelUnsupportedException: if a check fails or the expanded
                outputs do not match the expected output descriptors.
        """
        self._check()
        graph_builder = builder.GraphBuilder()
        with graph_builder.graph_scope(self.name) as graph_scope:
            # transform input_desc to Tensor
            self.inputs = [graph_builder.tensor(inp['shape'], inp['data_type'], inp['format']) for inp in self.inputs]
            graph_scope.set_input(*self.inputs)
            outputs = self._expand(graph_builder)
            if isinstance(outputs, (list, tuple)):
                self._check_output_same(outputs)
                graph_scope.set_output(*outputs)
            else:
                self._check_output_same([outputs])
                graph_scope.set_output(outputs)

        graph = graph_builder.get()[0]
        graph.set_processor(self.processor)
        return graph

    def _check(self):
        """Check inputs"""

    def _check_output_same(self, outputs):
        """
        Verify the expanded outputs against the expected output descriptors.

        Args:
            outputs (list or tuple): tensors produced by `_expand`.

        Raises:
            GraphKernelUnsupportedException: if fewer outputs were produced than
                expected, or a shape, data type or format differs.
        """
        # Without this check a short `outputs` surfaced as a bare IndexError
        # instead of the exception type the graph-kernel flow reports.
        if len(outputs) < len(self.outputs):
            raise GKException("{} produced {} outputs, but {} are expected".format(
                self.__class__.__name__, len(outputs), len(self.outputs)))
        # Outputs beyond the expected ones are not checked.
        for out, value in zip(outputs, self.outputs):
            if list(out.shape) != list(value['shape']):
                raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
                    self.__class__.__name__, list(out.shape), list(value['shape'])))
            if out.dtype != value['data_type']:
                raise GKException("{} 's output data_type {} is wrong. Expected: {}".format(
                    self.__class__.__name__, out.dtype, value['data_type']))
            if out.data_format != value['format']:
                raise GKException("{} 's output format {} is wrong. Expected: {}".format(
                    self.__class__.__name__, out.data_format, value['format']))

    @abstractmethod
    def _expand(self, graph_builder):
        """Expand operator, this function should be overridden in subclass"""
        raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
class ExpanderInfoValidator:
    """
    Utility class providing class decorators that validate expander info.

    Each decorator chains one more check onto the decorated expander's
    `_check` method, so the checks run when `Expander.run` is called.
    """

    def __init__(self):
        """Nothing to initialise; every member is a static method."""

    @staticmethod
    def _add_check_function(kls, func):
        """
        Replace `kls._check` with a function that first runs the previous
        `_check` and then `func`, both on the same object.
        """
        previous_check = getattr(kls, "_check")

        def chained_check(obj):
            previous_check(obj)
            func(obj)

        setattr(kls, "_check", chained_check)

    @staticmethod
    def add_format(*input_format):
        """
        Register one supported combination of input formats for the operator.

        The first use on a class creates the class attribute
        `__supported_formats` (a list) and chains a format check onto `_check`.
        Every use appends `input_format` to that list. At check time the
        inputs' formats must equal one registered combination, position by position.
        """
        registry_attr = "__supported_formats"

        def _check_format(obj):
            actual_formats = [desc['format'] for desc in obj.inputs]
            for candidate in getattr(obj, registry_attr):
                if len(candidate) != len(actual_formats):
                    raise GKException("For '{}', length of registered format is different from the length of inputs "
                                      "format: {} vs {}".format(obj.name, len(candidate), len(actual_formats)))
                if all(want == got for want, got in zip(candidate, actual_formats)):
                    return
            raise GKException("Unregistered format ({}) for op {}".format(','.join(actual_formats), obj.name))

        def wrapper(cls):
            if not issubclass(cls, Expander):
                raise Exception("{} should be subclass of Expander.".format(cls.__name__))
            if not hasattr(cls, registry_attr):
                # Only the first registration installs the check; later ones just extend the list.
                setattr(cls, registry_attr, [])
                ExpanderInfoValidator._add_check_function(cls, _check_format)
            getattr(cls, registry_attr).append(input_format)
            return cls

        return wrapper

    @staticmethod
    def check_all_formats_same(kls):
        """Class decorator: require every input of the operator to have the same format."""

        def _check_format(obj):
            formats = [desc['format'] for desc in obj.inputs]
            if all(fmt == formats[0] for fmt in formats[1:]):
                return
            raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
                ','.join(formats), obj.name))

        if not issubclass(kls, Expander):
            raise Exception("{} should be subclass of Expander.".format(kls.__name__))
        ExpanderInfoValidator._add_check_function(kls, _check_format)
        return kls

    @staticmethod
    def check_attrs(*args):
        """Class decorator factory: require every name in `args` to be present in the operator's attrs."""

        def _check_attr(obj):
            for attr_name in args:
                if attr_name not in obj.attrs:
                    raise GKException("attr '{}' does not exist. {}".format(attr_name, obj.name))

        def wrapper(cls):
            if not issubclass(cls, Expander):
                raise Exception("{} should be subclass of Expander.".format(cls.__name__))
            ExpanderInfoValidator._add_check_function(cls, _check_attr)
            return cls

        return wrapper
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
def to_frac_z_axis(ori_shape, ori_axis):
    """
    Map axes of an original shape to the corresponding FRACTAL_NZ axes.

    Parameters
    ----------
    ori_shape: list or tuple
        original shape of input
    ori_axis: list or tuple
        original axis of original shape to operate (negative values allowed)

    Returns
    -------
    output: list
        axis of the fractal Nz shape. Entry i is the mapped position of
        ori_axis[i]. For every axis that is one of the last two original axes,
        its inner fractal axis is appended after all mapped entries, in order.
    """
    shape_len = len(ori_shape)
    last_axis = shape_len - 1
    second_last_axis = shape_len - 2
    mapped = []
    inner_axes = []
    for axis in ori_axis:
        axis_index = (axis + shape_len) % shape_len
        if axis_index == last_axis:
            # Last original dim = fractal[-4] * fractal[-1] (see
            # infer_shape_from_fractalnz): outer at shape_len - 2, inner at shape_len + 1.
            mapped.append(axis_index - 1)
            inner_axes.append(axis_index + 2)
        elif axis_index == second_last_axis:
            # Second-to-last original dim = fractal[-3] * fractal[-2].
            mapped.append(axis_index + 1)
            inner_axes.append(axis_index + 2)
        else:
            # Batch dims keep their position.
            mapped.append(axis_index)
    return mapped + inner_axes
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
def infer_shape_from_fractalnz(fractal):
    """
    Get the original shape from a FRACTAL_NZ shape.

    For an input of rank r, the first r - 4 dims are copied unchanged. They are
    followed by fractal[r-3] * fractal[r-2] and fractal[r-4] * fractal[r-1].
    """
    rank = len(fractal)
    leading_dims = [fractal[idx] for idx in range(rank - 4)]
    rows = fractal[rank - 3] * fractal[rank - 2]
    cols = fractal[rank - 4] * fractal[rank - 1]
    return leading_dims + [rows, cols]
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
def get_reduced_ori_shape(shape, axis):
    """Return `shape` as a list with every dimension whose index is in `axis` replaced by 1."""
    return [1 if dim_idx in axis else dim for dim_idx, dim in enumerate(shape)]
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
def get_reduce_axis_shape(shape, data_format, axis):
    """
    Get the reduce axis under format `data_format` and original reduced shape.

    Parameters
    ----------
    shape: list or tuple
        shape of input
    data_format: str
        data format of input
    axis: None, int, list or tuple
        reduce axis of the original shape. None or an empty list/tuple means
        reducing all axes; negative values count from the end. The caller's
        object is never modified.

    Returns
    -------
    reduce_axis: list
        reduce axis of the `data_format` shape
    ori_reduced_shape: list
        original reduced shape
    """
    ori_shape = shape
    if data_format == "FRACTAL_NZ":
        ori_shape = infer_shape_from_fractalnz(shape)
    rank = len(ori_shape)
    if isinstance(axis, int):
        # Handled before the emptiness test so that axis=0 means "axis 0",
        # not "all axes" (0 is falsy).
        axis = [axis]
    if not axis:
        axis = list(range(rank))
    else:
        # Build a new list: normalising in place would mutate the caller's
        # list and fails outright for tuples.
        axis = [a + rank if a < 0 else a for a in axis]

    ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis)
    reduce_axis = axis
    if data_format == "FRACTAL_NZ":
        reduce_axis = to_frac_z_axis(ori_shape, axis)
    return reduce_axis, ori_reduced_shape
|
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for addn"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.check_all_formats_same
class AddN(Expander):
    """Expand AddN into a left-to-right chain of binary Adds."""

    def _check(self):
        """Require at least two inputs (runs before the decorator-added format check)."""
        input_num = len(self.inputs)
        if input_num >= 2:
            return
        raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}"
                          .format(input_num))

    def _expand(self, graph_builder):
        """Emit ((x0 + x1) + x2) + ... and return the final sum."""
        acc = self.inputs[0]
        for operand in self.inputs[1:]:
            acc = graph_builder.emit('Add', [acc, operand])
        return acc
|
|
@@ -1,152 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for BatchNorm"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
from .expand_dims import ExpandDims
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.add_format(DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'momentum', 'epsilon')
class BatchNorm(Expander):
    """
    BatchNorm expander.

    Inputs (in order): x, scale, offset, mean, variance.
    Required attrs: 'is_training', 'momentum', 'epsilon'.
    Returns five tensors. In training mode they are
    (y, updated moving mean, updated moving variance, batch mean,
    1 / sqrt(batch variance + epsilon)).
    """

    @staticmethod
    def _expand_channel_dims(graph_builder, tensor, data_format):
        """
        Reshape a per-channel tensor so it can broadcast against x.

        Only DEFAULT / NCHW layouts are reshaped, to
        ExpandDims.infer_shape(shape, [-1, -1]). That is presumably (C,) -> (C, 1, 1);
        confirm against ExpandDims. Other layouts (NHWC) are returned unchanged.
        """
        if data_format in (DF.DEFAULT, DF.NCHW):
            return graph_builder.emit(
                'Reshape', [tensor], attrs={'shape': ExpandDims.infer_shape(tensor.shape, [-1, -1])})
        return tensor

    def _expand(self, graph_builder):
        """Expand BatchNorm; training mode is delegated to `_bn_train`."""
        # get op info
        input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])

        # A float16 x with float32 parameters is computed in float32 and y is cast back.
        input_x_ori_type = input_x.dtype
        input_x_new_type = input_x.dtype
        if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
                input_mean.dtype == "float32" and input_variance.dtype == "float32":
            input_x_new_type = "float32"
        if input_x_new_type != input_x_ori_type:
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})

        if self.attrs['is_training']:
            # _bn_train reads x from self.inputs[0], so hand it the (possibly cast) tensor.
            self.inputs[0] = input_x
            res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
            if input_x_new_type != input_x_ori_type:
                res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
            return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec

        # infer mode: y = offset + scale * (x - mean) / sqrt(variance + epsilon)
        data_format = input_x.data_format
        input_mean = self._expand_channel_dims(graph_builder, input_mean, data_format)
        input_scale = self._expand_channel_dims(graph_builder, input_scale, data_format)
        input_offset = self._expand_channel_dims(graph_builder, input_offset, data_format)
        x_sub = graph_builder.emit('Sub', [input_x, input_mean])
        x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
        var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
        var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
        var_add_sqrt = self._expand_channel_dims(graph_builder, var_add_sqrt, data_format)
        x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
        res_y = graph_builder.emit('Add', [input_offset, x_div])
        if input_x_new_type != input_x_ori_type:
            res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
        # In inference mode var_add fills all four auxiliary output slots.
        return res_y, var_add, var_add, var_add, var_add

    def _bn_train(self, graph_builder):
        """
        Expand BatchNorm for training mode.

        Returns:
            (y, updated moving mean, updated moving variance, batch mean,
             1 / sqrt(batch variance + epsilon)).

        Raises:
            GraphKernelUnsupportedException: if the number of elements reduced
                per channel is not greater than 1.
        """
        input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
        shape_x = input_x.shape
        data_format = input_x.data_format
        if data_format == DF.NHWC:
            reduce_axis = (0, 1, 2)
            num = shape_x[0] * shape_x[1] * shape_x[2]
        else:
            reduce_axis = (0, 2, 3)
            num = shape_x[0] * shape_x[2] * shape_x[3]
        if num <= 1:
            # 1 / num and the unbiased factor num / (num - 1) below fail for
            # num == 0 or 1 (and are meaningless for negative counts). Report the
            # case as unsupported instead of leaking a ZeroDivisionError.
            from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
            raise GKException("For 'BatchNorm', the number of elements per channel should be greater than 1 "
                              "in training mode, but got {}".format(num))
        num_rec_v = graph_builder.value(input_scale.dtype, 1.0 / num)

        # batch mean of x
        mean_sum = graph_builder.emit(
            'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])

        # biased batch variance of x: mean((x - mean)^2)
        mean_muls_expand = self._expand_channel_dims(graph_builder, mean_muls, data_format)
        var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
        sq_diff = graph_builder.emit('Mul', [var_sub, var_sub])
        var_sum = graph_builder.emit('ReduceSum', [sq_diff], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        batch_var = graph_builder.emit('Mul', [var_sum, num_rec_v])

        # y_sqrt_rec = 1 / sqrt(variance + epsilon); also returned so the backward pass can reuse it
        scalar_one = 1.0
        scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
        y_add = graph_builder.emit('Add', [batch_var, epsilon_v])
        y_sqrt = graph_builder.emit('Sqrt', [y_add])
        y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])

        # y = scale * (x - mean) * y_sqrt_rec + offset
        tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
        y_sqrt_rec_expand = self._expand_channel_dims(graph_builder, y_sqrt_rec, data_format)
        y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
        input_scale_expand = self._expand_channel_dims(graph_builder, input_scale, data_format)
        res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
        input_offset_expand = self._expand_channel_dims(graph_builder, input_offset, data_format)
        res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])

        # moving mean: (1 - momentum) * mean + momentum * batch_mean, written back via Assign
        momentum_sub = scalar_one - self.attrs['momentum']
        momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
        new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
        momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
        current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
        updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
        mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean])

        # moving variance uses the unbiased batch variance: batch_var * num / (num - 1)
        var_num = float(num) / (num - 1)
        var_num_v = graph_builder.value(input_scale.dtype, var_num)
        var_mul_update = graph_builder.emit('Mul', [var_num_v, batch_var])
        new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
        current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
        updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
        variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance])
        return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
|
|
@@ -1,105 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for BatchNormGrad"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
from .expand_dims import ExpandDims
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'epsilon')
class BatchNormGrad(Expander):
    """
    BatchNormGrad expander.

    The first five inputs are read as (dy, x, scale, save_mean, save_inv_variance)
    and the expansion returns (dx, dgamma, dbeta). In training mode
    save_inv_variance already holds 1 / sqrt(variance + epsilon). Otherwise it
    is treated as the variance and that inverse is computed here.
    """

    def _expand(self, graph_builder):
        dy = self.inputs[0]
        x = self.inputs[1]
        scale = self.inputs[2]
        save_mean = self.inputs[3]
        save_inv_var = self.inputs[4]

        # Per-channel reductions run over every axis except the channel axis.
        dims = x.shape
        if x.data_format != DF.NHWC:
            reduce_axis = (0, 2, 3)
            num = dims[0] * dims[2] * dims[3]
        else:
            reduce_axis = (0, 1, 2)
            num = dims[0] * dims[1] * dims[2]

        # float16 x / dy are computed in float32; dx is cast back at the end.
        x_dtype = x.dtype
        if x_dtype == 'float16':
            x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
        if dy.dtype == 'float16':
            dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
        neg_inv_num = graph_builder.value(scale.dtype, -1.0 / num)

        def reduce_channels(tensor):
            return graph_builder.emit('ReduceSum', [tensor], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})

        dbeta = reduce_channels(dy)

        training = self.attrs['is_training']
        if training:
            inv_std = save_inv_var
        else:
            eps = graph_builder.value(scale.dtype, self.attrs['epsilon'])
            var_eps = graph_builder.emit('Add', [save_inv_var, eps])
            std = graph_builder.emit('Sqrt', [var_eps])
            one = graph_builder.value(scale.dtype, 1.0)
            inv_std = graph_builder.emit('RealDiv', [one, std])

        # DEFAULT / NCHW per-channel tensors must be reshaped to broadcast against x; NHWC ones need not.
        needs_reshape = x.data_format in (DF.DEFAULT, DF.NCHW)

        def broadcastable(tensor):
            if not needs_reshape:
                return tensor
            return graph_builder.emit(
                'Reshape', [tensor], attrs={'shape': ExpandDims.infer_shape(tensor.shape, [-1, -1])})

        save_mean = broadcastable(save_mean)
        inv_std = broadcastable(inv_std)
        scale = broadcastable(scale)

        # dgamma = sum(dy * x_hat), with x_hat = (x - mean) * inv_std
        centered = graph_builder.emit('Sub', [x, save_mean])
        x_hat = graph_builder.emit('Mul', [centered, inv_std])
        dy_x_hat = graph_builder.emit('Mul', [dy, x_hat])
        dgamma = reduce_channels(dy_x_hat)

        if training:
            # dx = inv_std * scale * (dy - dbeta / num - x_hat * dgamma / num)
            neg_mean_dy = graph_builder.emit('Mul', [neg_inv_num, dbeta])
            dgamma_b = broadcastable(dgamma)
            neg_mean_dy = broadcastable(neg_mean_dy)
            x_hat_dgamma = graph_builder.emit('Mul', [x_hat, dgamma_b])
            neg_corr = graph_builder.emit('Mul', [neg_inv_num, x_hat_dgamma])
            dy_shifted = graph_builder.emit('Add', [dy, neg_mean_dy])
            total = graph_builder.emit('Add', [dy_shifted, neg_corr])
            scaled = graph_builder.emit('Mul', [scale, total])
            dx = graph_builder.emit('Mul', [inv_std, scaled])
        else:
            # dx = inv_std * scale * dy
            scaled_dy = graph_builder.emit('Mul', [scale, dy])
            dx = graph_builder.emit('Mul', [inv_std, scaled_dy])
        if x_dtype == 'float16':
            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})

        # Output tensors take the formats declared in output_desc.
        dx.data_format = self.outputs[0]['format']
        dgamma.data_format = self.outputs[1]['format']
        dbeta.data_format = self.outputs[2]['format']
        return dx, dgamma, dbeta
|
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for ClipByNormNoDivSum"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class ClipByNormNoDivSum(Expander):
    """
    ClipByNormNoDivSum expander.

    For inputs (x0, x1, x2, x3) it emits:
        mask = x0 > x1
        y = Maximum(Select(mask, Sqrt(Select(mask, x0, x2)), x0), x3)
    """

    def _expand(self, graph_builder):
        x0, x1, x2, x3 = self.inputs

        mask = graph_builder.emit('Greater', [x0, x1])
        # Where mask is false, x2 is fed to Sqrt instead of x0; that lane is
        # discarded by the second Select anyway.
        sqrt_operand = graph_builder.emit('Select', [mask, x0, x2])
        root = graph_builder.emit('Sqrt', [sqrt_operand])
        picked = graph_builder.emit('Select', [mask, root, x0])
        return graph_builder.emit('Maximum', [picked, x3])
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for cabs"""
|
|
16
|
-
from mindspore._extends.graph_kernel.expanders._utils import Expander
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class CAbs(Expander):
    """CAbs expander: |z| = sqrt(real(z) * real(z) + imag(z) * imag(z))."""

    def _expand(self, graph_builder):
        z = self.inputs[0]
        re_part = graph_builder.emit('CReal', [z])
        im_part = graph_builder.emit('CImag', [z])
        re_sq = graph_builder.emit('Mul', [re_part, re_part])
        im_sq = graph_builder.emit('Mul', [im_part, im_part])
        norm_sq = graph_builder.emit('Add', [re_sq, im_sq])
        return graph_builder.emit('Sqrt', [norm_sq])
|
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for cadd"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CAdd(Expander):
    """
    CAdd expander: addition involving complex tensors.

    Supported dtype combinations:
      * both inputs have the same dtype: real and imaginary parts are added separately;
      * exactly one input is complex64/complex128: the other input is added to
        its real part and its imaginary part is kept.
    Any other combination raises GraphKernelUnsupportedException. That covers
    complex64 mixed with complex128, or two different non-complex dtypes.
    """

    def _expand(self, graph_builder):
        input_x, input_y = self.inputs
        complex_types = ("complex64", "complex128")
        x_is_complex = input_x.dtype in complex_types
        y_is_complex = input_y.dtype in complex_types
        if input_x.dtype == input_y.dtype:
            x_real = graph_builder.emit('CReal', [input_x])
            y_real = graph_builder.emit('CReal', [input_y])
            x_imag = graph_builder.emit('CImag', [input_x])
            y_imag = graph_builder.emit('CImag', [input_y])
            result_real = graph_builder.emit('Add', [x_real, y_real])
            result_imag = graph_builder.emit('Add', [x_imag, y_imag])
            result = graph_builder.emit('Complex', [result_real, result_imag])
        elif x_is_complex and not y_is_complex:
            x_real = graph_builder.emit('CReal', [input_x])
            x_imag = graph_builder.emit('CImag', [input_x])
            x_real_add_y = graph_builder.emit('Add', [x_real, input_y])
            result = graph_builder.emit('Complex', [x_real_add_y, x_imag])
        elif y_is_complex and not x_is_complex:
            y_real = graph_builder.emit('CReal', [input_y])
            y_imag = graph_builder.emit('CImag', [input_y])
            y_real_add_x = graph_builder.emit('Add', [y_real, input_x])
            result = graph_builder.emit('Complex', [y_real_add_x, y_imag])
        else:
            # Previously this fell through with `result` unbound (UnboundLocalError)
            # or added a complex tensor to a real part; report it as unsupported instead.
            from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
            raise GKException("For 'CAdd', unsupported input dtypes: {} and {}".format(input_x.dtype, input_y.dtype))
        return result
|