mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.10__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +46 -19
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +38 -0
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for fused_adam_weight_decay"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class FusedAdamWeightDecay(Expander):
    """Expander that lowers FusedAdamWeightDecay into basic element-wise ops.

    The emitted graph computes, in this order:
        next_m    = beta_1 * m + one_sub_beta_1 * gradient
        next_v    = beta_2 * v + one_sub_beta_2 * gradient * gradient
        update    = next_m / (sqrt(next_v) + eps) + weight_decay * param
        next_para = param - lr * update
    followed by three chained 'InplaceAssign' ops for (param, next_para),
    (m, next_m) and (v, next_v).
    """

    def _expand(self, graph_builder):
        """Emit the update ops and return the output of the last 'InplaceAssign'."""
        (beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr,
         param, m, v, gradient, weight_decay) = self.inputs
        emit = graph_builder.emit

        # first moment: beta_1 * m + (1 - beta_1) * gradient
        scaled_m = emit('Mul', [beta_1, m])
        scaled_grad = emit('Mul', [one_sub_beta_1, gradient])
        next_m = emit('Add', [scaled_m, scaled_grad])

        # second moment: beta_2 * v + (1 - beta_2) * gradient^2
        scaled_v = emit('Mul', [beta_2, v])
        grad_sq = emit('Mul', [gradient, gradient])
        scaled_grad_sq = emit('Mul', [one_sub_beta_2, grad_sq])
        next_v = emit('Add', [scaled_v, scaled_grad_sq])

        # update = next_m / (sqrt(next_v) + eps) + weight_decay * param
        root_v = emit('Sqrt', [next_v])
        denominator = emit('Add', [root_v, eps])
        update = emit('RealDiv', [next_m, denominator])
        decayed_param = emit('Mul', [weight_decay, param])
        update = emit('Add', [update, decayed_param])

        # next_para = param - lr * update
        step = emit('Mul', [lr, update])
        next_para = emit('Sub', [param, step])

        # Each InplaceAssign takes the previous result as its third input, which
        # links the three assignments into a single chain.
        para_result = next_para
        for target, new_value in ((param, next_para), (m, next_m), (v, next_v)):
            para_result = emit('InplaceAssign', [target, new_value, para_result], attrs={'fake_output': True})

        return para_result
|
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for FusedMulAdd"""
|
|
16
|
-
from ._utils import Expander
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class FusedMulAdd(Expander):
    """Expander that rewrites FusedMulAdd as a 'Mul' followed by an 'Add'."""

    def _expand(self, graph_builder):
        """Emit inputs[0] * inputs[1] + inputs[2] and return the 'Add' output."""
        lhs, rhs, addend = self.inputs
        product = graph_builder.emit('Mul', [lhs, rhs])
        return graph_builder.emit('Add', [product, addend])
|
|
@@ -1,70 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for gelugrad"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class GeLUGrad(Expander):
    """Expander that lowers GeLUGrad (tanh approximation) into basic ops.

    With
        tanh_para = sqrt(2 / pi) * (x + CSVALUE * x^3)
        mul_right = sqrt(2 / pi) * (1 + CSVALUE_TRI * x^2)
    the emitted graph computes
        y'     = 0.5 * (1 + tanh(tanh_para))
                 + 0.5 * x * (1 - tanh(tanh_para)^2) * mul_right
        result = dy * y'
    """
    CSVALUE = 0.044715
    CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
    # NOTE(review): intended as CSVALUE * 3, but 3 * 0.044715 == 0.134145.
    # The value is kept unchanged here to preserve numerical behaviour.
    CSVALUE_TRI = 0.134141

    def _expand(self, graph_builder):
        """Emit the gradient ops; the third input is not used."""
        input_dy, input_x, _ = self.inputs
        emit = graph_builder.emit

        def const(number):
            # every constant takes the dtype of dy
            return graph_builder.value(input_dy.dtype, number)

        cs = const(self.CSVALUE)
        sqrt_two_div_pi = const(self.CSVALUE_SQRT_TWO_DIV_PI)
        cs_tri = const(self.CSVALUE_TRI)
        one = const(1)
        half = const(0.5)

        # mul_right = sqrt(2 / pi) * (1 + CSVALUE_TRI * x^2)
        x_sq = emit('Mul', [input_x, input_x])
        tri_x_sq = emit('Mul', [cs_tri, x_sq])
        one_plus_tri_x_sq = emit('Add', [one, tri_x_sq])
        mul_right = emit('Mul', [sqrt_two_div_pi, one_plus_tri_x_sq])

        # tanh_para = sqrt(2 / pi) * (x + CSVALUE * x^3)
        x_cube = emit('Mul', [input_x, x_sq])
        cs_x_cube = emit('Mul', [cs, x_cube])
        x_plus_cs_x_cube = emit('Add', [input_x, cs_x_cube])
        tanh_para = emit('Mul', [sqrt_two_div_pi, x_plus_cs_x_cube])

        # left term: 0.5 * (1 + tanh(tanh_para))
        tanh_res = emit('Tanh', [tanh_para])
        one_plus_tanh = emit('Add', [one, tanh_res])
        left_term = emit('Mul', [half, one_plus_tanh])

        # right term: 0.5 * x * (1 - tanh(tanh_para)^2) * mul_right
        tanh_sq = emit('Mul', [tanh_res, tanh_res])
        one_minus_tanh_sq = emit('Sub', [one, tanh_sq])
        half_x = emit('Mul', [half, input_x])
        right_partial = emit('Mul', [half_x, one_minus_tanh_sq])
        right_term = emit('Mul', [right_partial, mul_right])

        # result = dy * (left term + right term)
        derivative = emit('Add', [left_term, right_term])
        return emit('Mul', [input_dy, derivative])
|
|
@@ -1,40 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for GkDropout"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
@VLD.check_attrs('keep_prob')
class GkDropout(Expander):
    """Expander that lowers GkDropout into compare, cast and multiply ops.

    The second input is compared against 'keep_prob' with 'LessEqual'; the
    comparison result, cast to the dtype of the first input, is the returned
    mask. The first input is scaled by 1 / keep_prob and multiplied by it.
    """

    def _expand(self, graph_builder):
        """Return the tuple (dropout output, mask)."""
        input_x, input_mask = self.inputs
        prob = self.attrs['keep_prob']
        emit = graph_builder.emit

        inv_keep_prob = graph_builder.value(input_x.dtype, 1.0 / prob)
        keep_prob_const = graph_builder.value(input_x.dtype, prob)

        # bring the second input to the dtype of x before comparing
        if input_mask.dtype != input_x.dtype:
            input_mask = emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
        keep_flags = emit('LessEqual', [input_mask, keep_prob_const])  # output is bool type
        keep_flags = emit('Cast', [keep_flags], attrs={'dst_type': input_x.dtype})

        # output = (1 / keep_prob) * x * mask
        scaled_x = emit('Mul', [inv_keep_prob, input_x])
        output = emit('Mul', [scaled_x, keep_flags])

        return output, keep_flags
|
|
@@ -1,25 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for Identity"""
|
|
16
|
-
from ._utils import Expander
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class Identity(Expander):
    """Expander that re-emits its first input through a 'Reshape' to the input's own shape."""

    def _expand(self, graph_builder):
        """Return a 'Reshape' of inputs[0] whose target shape is inputs[0].shape."""
        source = self.inputs[0]
        return graph_builder.emit('Reshape', [source], attrs={'shape': source.shape})
|
|
@@ -1,93 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for LayerNorm"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
class LayerNorm(Expander):
    """Expander that lowers LayerNorm into reduce and element-wise ops.

    Over the axes from 'begin_norm_axis' to the last one it emits
        mean     = sum(x) / n
        variance = sum((x - mean)^2) / n
        res      = (x - mean) * rsqrt(variance + epsilon) * gamma + beta
    where n is the number of reduced elements. On 'aicore' with float16
    input the computation is carried out in float32 and the three outputs
    are cast back to float16.
    """

    def _expand(self, graph_builder):
        """Return the tuple (res, mean, variance)."""
        input_x, input_gamma, input_beta = self.inputs
        begin_norm_axis = self.attrs['begin_norm_axis']
        epsilon = self.attrs['epsilon']
        emit = graph_builder.emit

        ori_dtype = input_x.dtype
        compute_in_fp32 = self.processor == 'aicore' and ori_dtype == 'float16'
        if compute_in_fp32:
            input_x, input_gamma, input_beta = [
                emit('Cast', [tensor], attrs={'dst_type': 'float32'})
                for tensor in (input_x, input_gamma, input_beta)
            ]

        # For FRAC_NZ data the axes are described on the original (non-fractal) shape.
        is_frac_nz = input_x.data_format == DF.FRAC_NZ
        ori_shape_x = infer_shape_from_fractalnz(input_x.shape) if is_frac_nz else input_x.shape

        # Calculate the scaling ratio of the average
        if begin_norm_axis < 0:
            begin_norm_axis += len(ori_shape_x)
        reduce_axis = tuple(idx for idx, _ in enumerate(ori_shape_x) if idx >= begin_norm_axis)

        reduce_elts = 1.0
        for idx in reduce_axis:
            reduce_elts *= ori_shape_x[idx]
        # shape after reduction, expressed on the original shape
        ori_reduced_shape_x = get_reduced_ori_shape(ori_shape_x, reduce_axis)

        axis = to_frac_z_axis(ori_shape_x, reduce_axis) if is_frac_nz else reduce_axis
        mean_cof_v = graph_builder.value(input_x.dtype, 1.0 / reduce_elts)

        # mean = sum(x) / n
        x_sum = emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
        mean = emit('Mul', [x_sum, mean_cof_v])
        if is_frac_nz:
            mean = emit('Reshape', [mean], attrs={'shape': ori_reduced_shape_x})

        # variance = sum((x - mean)^2) / n
        centered = emit('Sub', [input_x, mean])
        centered_sq = emit('Mul', [centered, centered])
        sq_sum = emit('ReduceSum', [centered_sq], attrs={'reduce_axis': axis, 'keep_dims': True})
        variance = emit('Mul', [sq_sum, mean_cof_v])
        if is_frac_nz:
            variance = emit('Reshape', [variance], attrs={'shape': ori_reduced_shape_x})

        # normalized = (x - mean) * rsqrt(variance + epsilon)
        centered_again = emit('Sub', [input_x, mean])
        epsilon_v = graph_builder.value(input_x.dtype, epsilon)
        var_plus_eps = emit('Add', [variance, epsilon_v])
        inv_std = emit('Rsqrt', [var_plus_eps])
        normalized = emit('Mul', [centered_again, inv_std])

        # res = normalized * gamma + beta
        scaled = emit('Mul', [normalized, input_gamma])
        res = emit('Add', [scaled, input_beta])

        if compute_in_fp32:
            res = emit('Cast', [res], attrs={'dst_type': 'float16'})
            mean = emit('Cast', [mean], attrs={'dst_type': 'float16'})
            variance = emit('Cast', [variance], attrs={'dst_type': 'float16'})
        return res, mean, variance
|
|
@@ -1,113 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for LayerNormGrad"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis')
class LayerNormGrad(Expander):
    """Expander that lowers LayerNormGrad into reduce and element-wise ops.

    Inputs are (x, dy, variance, mean, gamma); outputs are (dx, dg, db).
    'epsilon' is optional and defaults to 1e-12 when the attr is absent.
    On 'aicore' with float16 input the computation is carried out in
    float32 and the outputs are cast back to float16.
    """

    def _expand(self, graph_builder):
        """Return the tuple (dx, dg, db)."""
        x, dy, variance, mean, gamma = self.inputs
        begin_norm_axis = self.attrs['begin_norm_axis']
        begin_params_axis = self.attrs['begin_params_axis']
        epsilon = self.attrs.get('epsilon', 1e-12)
        emit = graph_builder.emit

        def reduce_sum(tensor, axes, keep_dims):
            return emit('ReduceSum', [tensor], attrs={'reduce_axis': axes, 'keep_dims': keep_dims})

        ori_dtype = x.dtype
        compute_in_fp32 = self.processor == 'aicore' and ori_dtype == 'float16'
        if compute_in_fp32:
            x, dy, variance, mean, gamma = [
                emit('Cast', [tensor], attrs={'dst_type': 'float32'})
                for tensor in (x, dy, variance, mean, gamma)
            ]

        rank = len(x.shape)
        if begin_norm_axis < 0:
            begin_norm_axis += rank
        if begin_params_axis < 0:
            begin_params_axis += rank

        norm_axis = tuple(range(begin_norm_axis, rank))
        param_axis = tuple(range(0, begin_params_axis))

        # number of elements reduced along the normalized axes
        reduce_size = 1.0
        for idx in norm_axis:
            reduce_size *= x.shape[idx]

        # constants, all in the dtype of x
        eps = graph_builder.value(x.dtype, epsilon)
        neg_half = graph_builder.value(x.dtype, -0.5)
        neg_two = graph_builder.value(x.dtype, -2.0)
        two = graph_builder.value(x.dtype, 2.0)
        neg_one = graph_builder.value(x.dtype, -1.0)
        mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size))

        # rsqrt(variance + eps), emitted as exp(-0.5 * log(variance + eps))
        var_eps = emit('Add', [variance, eps])
        log_var_eps = emit('Log', [var_eps])
        neg_half_log = emit('Mul', [log_var_eps, neg_half])
        rsqrt_var_eps = emit('Exp', [neg_half_log])

        x_sub_mean = emit('Sub', [x, mean])

        # dg = sum(dy * (x - mean) * rsqrt), db = sum(dy), both over param_axis
        x_hat = emit('Mul', [rsqrt_var_eps, x_sub_mean])
        dy_x_hat = emit('Mul', [dy, x_hat])
        dg = reduce_sum(dy_x_hat, param_axis, False)
        db = reduce_sum(dy, param_axis, False)

        # pd_var = -0.5 * sum(dy * gamma * (x - mean)) * rsqrt^3
        rsqrt_sq = emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
        rsqrt_cube = emit('Mul', [rsqrt_var_eps, rsqrt_sq])

        dy_mul_gamma = emit('Mul', [dy, gamma])
        dy_gamma_centered = emit('Mul', [dy_mul_gamma, x_sub_mean])
        var_sum = reduce_sum(dy_gamma_centered, norm_axis, True)
        var_sum_scaled = emit('Mul', [var_sum, rsqrt_cube])
        pd_var = emit('Mul', [var_sum_scaled, neg_half])

        # pd_mean = -rsqrt * sum(dy * gamma) + pd_var * mean_cof * sum(-2 * (x - mean))
        dy_gamma_sum = reduce_sum(dy_mul_gamma, norm_axis, True)
        neg_rsqrt = emit('Mul', [rsqrt_var_eps, neg_one])
        pd_mean_1 = emit('Mul', [neg_rsqrt, dy_gamma_sum])

        neg_two_centered = emit('Mul', [neg_two, x_sub_mean])
        neg_two_centered_sum = reduce_sum(neg_two_centered, norm_axis, True)
        neg_two_centered_mean = emit('Mul', [neg_two_centered_sum, mean_cof])
        pd_mean_2 = emit('Mul', [neg_two_centered_mean, pd_var])

        pd_mean = emit('Add', [pd_mean_1, pd_mean_2])

        # dx = dy * gamma * rsqrt + pd_var * 2 * (x - mean) * mean_cof + pd_mean * mean_cof
        pd_x_1 = emit('Mul', [dy_mul_gamma, rsqrt_var_eps])

        var_centered = emit('Mul', [pd_var, x_sub_mean])
        var_centered_two = emit('Mul', [var_centered, two])
        pd_x_2 = emit('Mul', [var_centered_two, mean_cof])

        pd_x_3 = emit('Mul', [pd_mean, mean_cof])

        dx_partial = emit('Add', [pd_x_1, pd_x_2])
        dx = emit('Add', [dx_partial, pd_x_3])

        if compute_in_fp32:
            dx = emit('Cast', [dx], attrs={'dst_type': 'float16'})
            dg = emit('Cast', [dg], attrs={'dst_type': 'float16'})
            db = emit('Cast', [db], attrs={'dst_type': 'float16'})
        return dx, dg, db
|
|
@@ -1,46 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for LogSoftmax"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis')
class LogSoftmax(Expander):
    """Expander that lowers LogSoftmax into reduce and element-wise ops.

    Emits (x - max(x)) - log(sum(exp(x - max(x)))) over 'axis'. On 'aicore'
    with a non-float16 input, the 'ReduceMax' is performed on a float16 cast
    of x and its result is cast back to the original dtype.
    """

    def _expand(self, graph_builder):
        """Return the log-softmax output node."""
        input_x = self.inputs[0]
        axis = self.attrs['axis']
        emit = graph_builder.emit

        # a single int axis is wrapped into a tuple
        reduce_axis = (axis,) if isinstance(axis, int) else axis

        ori_dtype = input_x.dtype
        if ori_dtype != "float16" and self.processor == "aicore":
            x_fp16 = emit('Cast', [input_x], attrs={'dst_type': 'float16'})
            max_fp16 = emit('ReduceMax', [x_fp16], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
            x_max = emit('Cast', [max_fp16], attrs={'dst_type': ori_dtype})
        else:
            x_max = emit('ReduceMax', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})

        shifted = emit('Sub', [input_x, x_max])
        exp_shifted = emit('Exp', [shifted])
        exp_sum = emit('ReduceSum', [exp_shifted], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
        log_exp_sum = emit('Log', [exp_sum])
        return emit('Sub', [shifted, log_exp_sum])
|
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for LogSoftmaxGrad"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('axis')
class LogSoftmaxGrad(Expander):
    """Expander that lowers LogSoftmaxGrad into basic ops.

    Emits dy - exp(logits) * sum(dy) with the sum taken over 'axis'
    (keep_dims=True).
    """

    def _expand(self, graph_builder):
        """Return the gradient output node."""
        input_logits, input_dy = self.inputs
        axis = self.attrs['axis']
        emit = graph_builder.emit

        # a single int axis is wrapped into a tuple
        reduce_axis = (axis,) if isinstance(axis, int) else axis

        probs = emit('Exp', [input_logits])
        dy_total = emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
        weighted = emit('Mul', [probs, dy_total])
        return emit('Sub', [input_dy, weighted])
|
|
@@ -1,80 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for BatchMatMul and MatMul"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
18
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format')
class MatMul(Expander):
    """
    MatMul expander

    Replaces a MatMul whose contraction dimension k is 1 by an element-wise
    'Mul'. Any other case raises GraphKernelUnsupportedException.
    """

    def __init__(self, expand_info):
        """Record both input shapes; transpose flags and formats are filled in by `_expand`."""
        super(MatMul, self).__init__(expand_info)
        self.shape_a = self.inputs[0]['shape']
        self.shape_b = self.inputs[1]['shape']
        self.transpose_a = False
        self.transpose_b = False
        self.left_format = ""
        self.right_format = ""

    def _optimize_to_mul(self):
        """check if matmul can be replace by mul"""
        if self.processor != 'aicore':
            return False
        if self.left_format != DF.DEFAULT or self.right_format != DF.DEFAULT:
            return False
        # k is the contraction dimension on each side, taking the transpose flags into account
        k_a = self.shape_a[-2] if self.transpose_a else self.shape_a[-1]
        k_b = self.shape_b[-1] if self.transpose_b else self.shape_b[-2]
        return k_a == 1 and k_b == 1

    def _check(self):
        """Raise GraphKernelUnsupportedException when fewer than two inputs are given."""
        input_num = len(self.inputs)
        if input_num < 2:
            raise GKException("For 'MatMul', inputs number should bigger than 1, but got {}.".format(input_num))

    def _expand(self, graph_builder):
        """Emit the 'Mul' replacement (plus optional 'Reshape'/'Cast') and return its output."""
        attrs = self.attrs
        self.transpose_a = attrs['transpose_a']
        self.transpose_b = attrs['transpose_b']
        self.left_format = attrs['left_format']
        self.right_format = attrs['right_format']

        def swap_last_two(shape):
            # returns a list copy of `shape` with its last two dims exchanged
            swapped = list(shape)
            swapped[-2], swapped[-1] = shape[-1], shape[-2]
            return swapped

        if not self._optimize_to_mul():
            raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")

        # Matmul is replaced by Mul([b m k], [b k n]) when k==1
        lhs = self.inputs[0]
        rhs = self.inputs[1]
        if self.transpose_a:
            lhs = graph_builder.emit('Reshape', [lhs], attrs={'shape': swap_last_two(self.shape_a)})
        if self.transpose_b:
            rhs = graph_builder.emit('Reshape', [rhs], attrs={'shape': swap_last_two(self.shape_b)})
        result = graph_builder.emit('Mul', [lhs, rhs])

        # cast only when a destination type is requested and differs from the first input's dtype
        if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
            result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
        return result


class BatchMatMul(MatMul):
    """BatchMatMul expander"""
|
|
@@ -1,59 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for maximum_grad"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
from .minimum_grad import MinimumGrad
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.check_all_formats_same
class MaximumGrad(Expander):
    """Expander that lowers MaximumGrad into compare, multiply and reduce ops.

    With mask = cast(x >= y) it emits dx = mask * dout and dy = dout - dx.
    Each gradient is then reduced over the axes returned by
    `MinimumGrad.get_reduce_axis` and, if needed, reshaped to the shape of
    its input.
    """

    def _check(self):
        """Reject the case where both 'grad_x' and 'grad_y' are False.

        Raises:
            GraphKernelUnsupportedException: if both attrs are False. A
                missing attr counts as True.
        """
        if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
            # The original message said the attrs "should be False", which is
            # the opposite of the condition that triggers this error.
            raise GKException("For 'MaximumGrad', attrs 'grad_x' and 'grad_y' should not both be False, but got {} "
                              "and {}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y')))
        return super()._check()

    @staticmethod
    def _reduce_to_shape(graph_builder, grad, reduce_axis, target_shape):
        """Sum `grad` over `reduce_axis` (if non-empty) and reshape it to `target_shape` when shapes differ."""
        if not reduce_axis:
            return grad
        reduced = graph_builder.emit('ReduceSum', [grad], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        if reduced.shape != target_shape:
            return graph_builder.emit('Reshape', [reduced], attrs={'shape': target_shape})
        return reduced

    def _expand(self, graph_builder):
        """Return the tuple (dx_out, dy_out)."""
        input_x, input_y, input_dout = self.inputs

        # dx takes dout where x >= y; dy takes the remainder
        ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
        ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
        dx = graph_builder.emit('Mul', [ge_result, input_dout])
        dy = graph_builder.emit('Sub', [input_dout, dx])

        # both axis sets are computed before any reduce op is emitted, as in the original
        reduce_axis_x = MinimumGrad.get_reduce_axis(input_x.shape, dx.shape)
        reduce_axis_y = MinimumGrad.get_reduce_axis(input_y.shape, dy.shape)

        dx_out = self._reduce_to_shape(graph_builder, dx, reduce_axis_x, input_x.shape)
        dy_out = self._reduce_to_shape(graph_builder, dy, reduce_axis_y, input_y.shape)

        # output two results, regardless of grad_x and grad_y
        return dx_out, dy_out
|