mindspore 2.3.0rc1-cp37-none-any.whl → 2.3.0rc2-cp37-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +20 -0
- mindspore/_extends/parse/parser.py +1 -1
- mindspore/_extends/parse/standard_method.py +6 -5
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +5 -5
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -2
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_stub_tensor.py +1 -0
- mindspore/common/api.py +56 -4
- mindspore/common/dtype.py +5 -3
- mindspore/common/dump.py +2 -2
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +17 -6
- mindspore/common/parameter.py +7 -2
- mindspore/common/recompute.py +247 -0
- mindspore/common/sparse_tensor.py +2 -2
- mindspore/common/symbol.py +1 -1
- mindspore/common/tensor.py +74 -36
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/management.py +30 -30
- mindspore/context.py +28 -15
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +51 -51
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +3 -3
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +3 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +3 -3
- mindspore/dataset/engine/datasets_vision.py +68 -68
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +26 -26
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/transforms.py +92 -92
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/experimental/optim/adadelta.py +2 -2
- mindspore/experimental/optim/adagrad.py +2 -2
- mindspore/experimental/optim/adam.py +2 -2
- mindspore/experimental/optim/adamax.py +2 -2
- mindspore/experimental/optim/adamw.py +2 -2
- mindspore/experimental/optim/asgd.py +2 -2
- mindspore/experimental/optim/lr_scheduler.py +24 -20
- mindspore/experimental/optim/nadam.py +2 -2
- mindspore/experimental/optim/optimizer.py +1 -1
- mindspore/experimental/optim/radam.py +2 -2
- mindspore/experimental/optim/rmsprop.py +2 -2
- mindspore/experimental/optim/rprop.py +2 -2
- mindspore/experimental/optim/sgd.py +2 -2
- mindspore/hal/stream.py +2 -0
- mindspore/include/mindapi/base/types.h +5 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +101787 -98559
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/base/op_register.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h +8 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/norm.h +5 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/reduce.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/backend.h +3 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/rtbackend.h +3 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/module/module.h +3 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/svector/svector.h +3 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/tiling/add_tiling.h +9 -9
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +2 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_base.h +460 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_bf16.h +217 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h +116 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h +16 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h +27 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -4
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h} +2 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h} +15 -19
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h +7 -9
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h +58 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +19 -8
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h +18 -8
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h +7 -4
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h +44 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h +65 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +10 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +4 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +41 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/PagedAttention_impl.h → paged_attention/paged_attention_impl.h} +1 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +63 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h} +11 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +37 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +45 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h +1 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h +23 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h +175 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_normal.h +276 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_split_d.h +280 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/tiling_data.h +35 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +45 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/kernel/sub_kernel.h +20 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +47 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_tiling.h +25 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +323 -23
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/types.h +15 -4
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +8 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal.h +22 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_comm.h +70 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_types.h +103 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl.h +47 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl_wrapper.h +58 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcoc.h +154 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +2 -2
- mindspore/mint/__init__.py +457 -0
- mindspore/mint/nn/__init__.py +430 -0
- mindspore/mint/nn/functional.py +424 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +186 -0
- mindspore/multiprocessing/__init__.py +4 -0
- mindspore/nn/__init__.py +3 -0
- mindspore/nn/cell.py +51 -47
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/nn/extend/layer/__init__.py +27 -0
- mindspore/nn/extend/layer/normalization.py +107 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/basic.py +109 -1
- mindspore/nn/layer/container.py +2 -2
- mindspore/nn/layer/conv.py +6 -6
- mindspore/nn/layer/embedding.py +1 -1
- mindspore/nn/layer/normalization.py +21 -43
- mindspore/nn/layer/padding.py +4 -0
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +7 -7
- mindspore/nn/optim/adamax.py +2 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +1 -1
- mindspore/nn/optim/lamb.py +3 -3
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +2 -2
- mindspore/nn/optim/momentum.py +2 -2
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +2 -2
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/sgd.py +2 -2
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +9 -9
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
- mindspore/ops/_vmap/vmap_math_ops.py +27 -8
- mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
- mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
- mindspore/ops/auto_generate/gen_extend_func.py +274 -0
- mindspore/ops/auto_generate/gen_ops_def.py +889 -22
- mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
- mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
- mindspore/ops/extend/__init__.py +9 -1
- mindspore/ops/extend/array_func.py +134 -27
- mindspore/ops/extend/math_func.py +3 -3
- mindspore/ops/extend/nn_func.py +363 -2
- mindspore/ops/function/__init__.py +19 -2
- mindspore/ops/function/array_func.py +463 -439
- mindspore/ops/function/clip_func.py +7 -18
- mindspore/ops/function/grad/grad_func.py +5 -5
- mindspore/ops/function/linalg_func.py +4 -4
- mindspore/ops/function/math_func.py +260 -243
- mindspore/ops/function/nn_func.py +825 -62
- mindspore/ops/function/random_func.py +73 -4
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +1 -1
- mindspore/ops/functional.py +2 -2
- mindspore/ops/op_info_register.py +1 -31
- mindspore/ops/operations/__init__.py +2 -3
- mindspore/ops/operations/_grad_ops.py +2 -107
- mindspore/ops/operations/_inner_ops.py +5 -5
- mindspore/ops/operations/_sequence_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +11 -233
- mindspore/ops/operations/comm_ops.py +32 -32
- mindspore/ops/operations/custom_ops.py +7 -89
- mindspore/ops/operations/manually_defined/ops_def.py +329 -4
- mindspore/ops/operations/math_ops.py +13 -163
- mindspore/ops/operations/nn_ops.py +9 -316
- mindspore/ops/operations/random_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +3 -3
- mindspore/ops/primitive.py +2 -2
- mindspore/ops_generate/arg_dtype_cast.py +12 -3
- mindspore/ops_generate/arg_handler.py +24 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
- mindspore/ops_generate/gen_pyboost_func.py +13 -6
- mindspore/ops_generate/pyboost_utils.py +2 -17
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +106 -1
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_utils.py +16 -0
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/checkpoint_transform.py +249 -77
- mindspore/parallel/cluster/process_entity/_api.py +1 -1
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +1 -1
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
- mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
- mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
- mindspore/profiler/parser/ascend_op_generator.py +26 -9
- mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
- mindspore/profiler/parser/profiler_info.py +11 -1
- mindspore/profiler/profiling.py +13 -5
- mindspore/rewrite/api/node.py +12 -12
- mindspore/rewrite/api/symbol_tree.py +11 -11
- mindspore/run_check/_check_version.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +2 -2
- mindspore/train/amp.py +4 -4
- mindspore/train/anf_ir_pb2.py +8 -2
- mindspore/train/callback/_backup_and_restore.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +2 -2
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +2 -2
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/dataset_helper.py +8 -3
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/mind_ir_pb2.py +22 -17
- mindspore/train/model.py +15 -15
- mindspore/train/serialization.py +18 -18
- mindspore/train/summary/summary_record.py +7 -7
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +307 -260
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/tiling_data.h +0 -59
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BSH_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BSH_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BSH_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BSH_mix.o +0 -0
- /mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_mix_hwsync.h → flash_attention_score/kernel/flash_attention_score_mix_hwsync.h} +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h
ADDED
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MATMUL_TILING_UTILS_H
+#define MATMUL_TILING_UTILS_H
+
+#include <stdint.h>
+#include <sstream>
+#include <cstdlib>
+#include <vector>
+
+namespace mindspore {
+namespace internal {
+namespace tiling {
+
+static std::vector<int> getMatMulTilingFromEnv() {
+  std::vector<int> result;
+  auto env_name = "INTERNAL_MATMUL_TILING";
+  const char *envVarValue = std::getenv(env_name);
+
+  if (envVarValue != nullptr) {
+    std::string envVarString(envVarValue);
+    std::stringstream ss(envVarString);
+    std::string item;
+
+    while (std::getline(ss, item, ',')) {
+      result.push_back(std::stoi(item));
+    }
+  }
+
+  return result;
+}
+
+
+static bool getShuffleFlagFromEnv() {
+  auto env_name = "CUSTOM_MATMUL_SHUFFLE";
+  const char *envVarValue = std::getenv(env_name);
+  if (envVarValue != nullptr) {
+    std::string envVarString(envVarValue);
+    if (envVarString != "0" && envVarString != "off") {
+      return true;
+    }
+    return false;
+  }
+  return true;
+}
+
+
+}  // namespace tiling
+}  // namespace internal
+}  // namespace mindspore
+#endif  // MATMUL_TILING_UTILS_H
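The two helpers above read undocumented tuning knobs from the process environment: INTERNAL_MATMUL_TILING is parsed as a comma-separated list of integers, and CUSTOM_MATMUL_SHUFFLE defaults to enabled unless set to "0" or "off". A minimal standalone sketch of the same parsing contract (ParseTiling is an illustrative stand-in; the header itself is internal and not installed as a public API):

#include <cassert>
#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>

// Mirrors getMatMulTilingFromEnv(): "8,16,4" -> {8, 16, 4}; unset -> {}.
std::vector<int> ParseTiling(const char *value) {
  std::vector<int> result;
  if (value == nullptr) return result;
  std::stringstream ss{std::string(value)};
  std::string item;
  while (std::getline(ss, item, ',')) result.push_back(std::stoi(item));
  return result;
}

int main() {
  setenv("INTERNAL_MATMUL_TILING", "8,16,4", 1);  // POSIX
  assert((ParseTiling(std::getenv("INTERNAL_MATMUL_TILING")) == std::vector<int>{8, 16, 4}));
  // CUSTOM_MATMUL_SHUFFLE: unset, or any value other than "0"/"off", means enabled.
  return 0;
}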
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h
@@ -24,11 +24,12 @@
 #include "asdops/tensor.h"
 
 #include "utils.h"
-// #include "pp_matmul_info.h"
 #include "backend_param.h"
-#include "
+#include "matmul_common/pp_matmul_info.h"
+#include "matmul_common/tiling_utils.h"
+#include "matmul_common/tiling_data.h"
+#include "matmul_common/pp_matmul_common_tiling.h"
 #include "param/matmul_qkv_param.h"
-// #include "pp_matmul_common_tiling.h"
 #include "tune_repo/utils.h"
 
 #include "internal_kernel.h"
@@ -39,6 +40,8 @@
 namespace mindspore {
 namespace internal {
 
+using namespace tiling;
+
 class MatMulStridedSliceFusionImpl : public InternelKernelImpl {
  public:
   MatMulStridedSliceFusionImpl(const OpParamPtr &param) : InternelKernelImpl(param){};
@@ -48,7 +51,8 @@ class MatMulStridedSliceFusionImpl : public InternelKernelImpl {
   int Launch() override;
   size_t GetTilingBufSize() override;
   int Tiling(HostRawBuf &tilingBuf) override;
-
+  void TilingBasicFromPp(uint32_t &blockDim, PpTilingData &tilingdata);
+  int TilingLLMCustom(HostRawBuf &tilingBuf, const uint32_t &blockDim, const PpTilingData &tilingdata, bool has_tuned);
   std::vector<uint64_t> GetWorkSpaceSize() override;
   int InferShape(const std::vector<DIMS> &input_shapes, std::vector<DIMS> &output_shapes) override;
 
@@ -66,8 +70,8 @@ class MatMulStridedSliceFusionImpl : public InternelKernelImpl {
 
   REPO tuningTable_;
   tiling::MatmulStridedSliceFusionTilingData t_;
-
-  void
+  std::vector<int> GetTunedKey();
+  void SetTunedValue(const std::vector<int> &tuned_config);
 };
 
 }  // namespace internal
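The new GetTunedKey / SetTunedValue declarations and the has_tuned flag on TilingLLMCustom suggest that tuningTable_ maps a shape key to a previously tuned tiling configuration. A purely hypothetical sketch of that flow (TuneKey, TuneConfig and TuneRepo are illustrative names, not the header's types):

#include <map>
#include <vector>

using TuneKey = std::vector<int>;     // hypothetical: e.g. {m, k, n}
using TuneConfig = std::vector<int>;  // hypothetical: tuned tiling parameters

struct TuneRepo {
  std::map<TuneKey, TuneConfig> table;
  // Returns true (the has_tuned path) when a tuned config exists for `key`.
  bool Lookup(const TuneKey &key, TuneConfig *out) const {
    auto it = table.find(key);
    if (it == table.end()) return false;  // fall back to generic tiling
    *out = it->second;
    return true;
  }
};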
mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h
@@ -27,7 +27,6 @@
 #include "asdops/params/norm.h"
 #include "asdops/params/softmax.h"
 #include "asdops/params/split.h"
-#include "attention_param.h"
 #include "asdops/params/expand.h"
 #include "asdops/params/fill.h"
 #include "asdops/params/reduce.h"
@@ -99,6 +98,10 @@ struct AddLayerNormParam {
 };
 
 struct ApplyRotaryPosEmbParam {
+  // cosFormat=0: shape is [maxSeqLen, headDim], cos/sin not interleaved
+  // cosFormat=1: shape is [maxSeqLen, headDim], cos/sin interleaved
+  // cosFormat=2: shape is [batch*seqLen, headDim], cos/sin not interleaved
+  // cosFormat=3: shape is [batch*seqLen, headDim], cos/sin interleaved
   int32_t cosFormat{0};
 };
 
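A small sketch restating the four cosFormat layouts from the comments above in code; CosTableLayout and its parameters are illustrative names, not part of the header:

#include <cstdint>
#include <utility>

// Rows of the cos/sin table and whether cos/sin are interleaved, per
// cosFormat. The table shape is [rows, headDim] in every case.
std::pair<int64_t, bool> CosTableLayout(int32_t cosFormat, int64_t maxSeqLen,
                                        int64_t batch, int64_t seqLen) {
  const bool interleaved = (cosFormat == 1 || cosFormat == 3);       // odd formats interleave
  const int64_t rows = (cosFormat < 2) ? maxSeqLen : batch * seqLen; // formats 0/1 vs 2/3
  return {rows, interleaved};
}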
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h
ADDED
@@ -0,0 +1,41 @@
+#ifndef BS_FLASHATTENTION_BS__ATTENTION_MIX_HWSYNC_H
+#define BS_FLASHATTENTION_BS__ATTENTION_MIX_HWSYNC_H
+constexpr float DROPOUT_PROP = 0.5;
+constexpr uint32_t LOOP_LEN = 5;
+constexpr uint32_t UB_HALF_BUF_SIZE = 8 * 2048;
+constexpr uint32_t BIT_UINT8 = 8;
+constexpr uint32_t BIT_BLOCK = 256;
+constexpr uint32_t BLOCK_SIZE = 16;
+constexpr uint32_t VECTOR_SIZE = 128;
+constexpr uint32_t VECTOR_SIZE_FP32 = 64;
+constexpr uint32_t CUBE_MATRIX_SIZE = 256;  // 16 * 16
+constexpr uint64_t UB_UINT8_BLOCK_SIZE = 16384;  // 64 * 128 * 2B
+constexpr uint64_t UB_UINT8_LINE_SIZE = 512;  // 64 * 4B; double space allocated to guard against overruns
+constexpr uint64_t UB_FLOAT_LINE_SIZE = 128;  // 64; double space allocated to guard against overruns
+constexpr uint64_t UB_HALF_LINE_SIZE = 256;  // UB_FLOAT_LINE_SIZE * 2
+
+constexpr uint32_t L0AB_HALF_BUF_SIZE = 16384;  // 128 * 128
+constexpr uint64_t L1_SIZE = 512 * 1024;  // 512KB
+constexpr uint64_t L0AB_UINT8_BLOCK_SIZE = 32768;  // 128 * 128 * 2B
+constexpr uint64_t L1_MAX_SHARE_NUM = (L1_SIZE - 8 * L0AB_UINT8_BLOCK_SIZE) / L0AB_UINT8_BLOCK_SIZE / 2;
+constexpr uint64_t SUB_SP_SIZE = 2048 * 8;  // 1024*16, 2048*8, 4096*4, 8192*2, 16K*1: five tiling choices
+
+enum class L1Mode{load,     // load data into the L1 share region
+                  share,    // consume data already in the share region
+                  noshare}; // neither load into nor use the share region
+
+inline uint64_t ceil(uint64_t y, uint64_t x) {
+  return (y + x - 1) / x;
+}
+
+inline uint64_t round(uint64_t y, uint64_t x) {
+  return ceil(y, x) * x;
+}
+
+#if BFLOAT16
+#define CALC_DATA_TYPE bfloat16_t
+#else
+#define CALC_DATA_TYPE half
+#endif
+
+#endif  // BS_FLASHATTENTION_BS__ATTENTION_MIX_HWSYNC_H
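ceil here is ceiling division and round rounds up to the next multiple, so a 100-token row spans ceil(100, 16) = 7 blocks of BLOCK_SIZE = 16 and pads to round(100, 16) = 112 elements. A tiny standalone sketch with the same semantics (renamed to avoid shadowing the standard library):

#include <cassert>
#include <cstdint>

inline uint64_t CeilDivU64(uint64_t y, uint64_t x) { return (y + x - 1) / x; }
inline uint64_t RoundUpU64(uint64_t y, uint64_t x) { return CeilDivU64(y, x) * x; }

int main() {
  constexpr uint64_t kBlockSize = 16;          // BLOCK_SIZE above
  assert(CeilDivU64(100, kBlockSize) == 7);    // 100 elements span 7 blocks of 16
  assert(RoundUpU64(100, kBlockSize) == 112);  // padded length on 16-element boundaries
  return 0;
}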
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h
ADDED
@@ -0,0 +1,63 @@
+#ifndef __BS_ATTENTION_TILING_H__
+#define __BS_ATTENTION_TILING_H__
+
+#pragma pack (8)
+typedef struct {
+  uint64_t batch_size;
+  uint64_t num_heads;
+  uint64_t max_seqlen;
+  uint64_t head_dim;
+  uint64_t num_group;
+  uint64_t q_seqlen;
+  uint64_t kv_seqlen;
+  uint64_t table_block_size;
+  uint64_t sync_addr;
+  uint64_t core_num;
+  float tor;
+} BSAttentionTilingData;
+#pragma pack()
+
+#define MAX_CORE_NUM 25
+#define ATTENTION_DEBUG false  // when enabled, debug data is written into S/P
+#define ROWMAX true
+#define OP_NAME PagedAttention
+#define BUFFER_NUM 4  // inter-core pipeline depth; changing it is not yet supported
+constexpr uint64_t WORKSPACE_MAX_SEQLEN = 16384;  // max seqlen
+constexpr uint64_t WORKSPACE_SIZE = 64 * WORKSPACE_MAX_SEQLEN;
+
+#if BFLOAT16
+#define TYPE_NAME _bf16
+#else
+#define TYPE_NAME _fp16
+#endif
+
+#if BSH
+#define LAYOUT_NAME _BSH
+#else
+#define LAYOUT_NAME _BNSD
+#endif
+
+#define TRI_NAME _full
+
+#define CONCAT_(A, B, C, D, E) A##B##C##D##E
+#define CONCAT(A, B, C, D, E) CONCAT_(A, B, C, D, E)
+#define FUNC_NAME_AIC CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aic)
+#define FUNC_NAME_AIV CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aiv)
+
+// ************** mask pattern modes **************//
+// Mode 1: lower triangle. When LOWER_TRIANGLE is enabled, the lower-triangular pattern is applied directly, with no dependency on a mask input.
+// #define LOWER_TRIANGLE false
+
+// Mode 2: block sparse. With LOWER_TRIANGLE off and BLOCK_SPARSE on, pre_token and next_token are used, with no dependency on a mask input (to be developed).
+// #define BLOCK_SPARSE false
+
+// Mode 3: read MASK. With LOWER_TRIANGLE and BLOCK_SPARSE off and AMASK on, the mask tensor is taken as an input.
+// #define AMASK true
+
+// Mode 4: full matrix. If LOWER_TRIANGLE, BLOCK_SPARSE and AMASK are all off, attention computes over the full matrix and no elements of S are suppressed.
+// *******************************************//
+
+constexpr uint64_t WORKSPACE_MAX_SEQLEN_BLOCK = WORKSPACE_MAX_SEQLEN / 16;
+constexpr uint64_t BUFFER_SIZE = MAX_CORE_NUM * WORKSPACE_SIZE * sizeof(uint16_t);
+
+#endif
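With BFLOAT16 and BSH defined, the token-pasting macros above assemble per-variant kernel entry names, the same naming scheme visible in the BSAttention kernel objects listed earlier. A preprocessor sketch of the expansion (the static_assert only checks the stringified length and is illustrative, not from the header):

// FUNC_NAME_AIC -> PagedAttention_bf16_BSH_full_mix_aic
// FUNC_NAME_AIV -> PagedAttention_bf16_BSH_full_mix_aiv
#define CONCAT_(A, B, C, D, E) A##B##C##D##E
#define CONCAT(A, B, C, D, E) CONCAT_(A, B, C, D, E)
#define OP_NAME PagedAttention
#define TYPE_NAME _bf16
#define LAYOUT_NAME _BSH
#define TRI_NAME _full

// Stringify to inspect the expansion at compile time.
#define STR_(x) #x
#define STR(x) STR_(x)
static_assert(sizeof(STR(CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aic))) ==
                sizeof("PagedAttention_bf16_BSH_full_mix_aic"),
              "length check on the stringified expansion");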
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h
@@ -34,11 +34,11 @@ struct AddParam : public OpParam {
   DIMS input1_dims_;
   DIMS input2_dims_;
   bool canSupport() {
-    if (ADD_SUPPORT_DTYPE.find(input1_dtype_) == ADD_SUPPORT_DTYPE.end()) {
+    if (ADD_SUPPORT_DTYPE.find(input1_dtype_) == ADD_SUPPORT_DTYPE.end() || input1_dims_ != input2_dims_) {
       return false;
     }
     if (input1_dims_ == input2_dims_) {
-      return
+      return false;
     }
     if (std::abs(int(input1_dims_.size()) - int(input2_dims_.size())) > 1) {
       return false;
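Read together, the two changed conditions make canSupport return false for every input pair: mismatched shapes fail the first check and matching shapes fail the second, so this fast path appears fully disabled by the change. A minimal standalone sketch of the predicate (DIMS and the dtype set here are stand-ins for the header's types):

#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

using DIMS = std::vector<int64_t>;
const std::set<int> ADD_SUPPORT_DTYPE = {0, 1};  // hypothetical dtype ids

bool canSupport(int dtype, const DIMS &a, const DIMS &b) {
  if (ADD_SUPPORT_DTYPE.find(dtype) == ADD_SUPPORT_DTYPE.end() || a != b) {
    return false;  // unsupported dtype or mismatched shapes
  }
  if (a == b) {
    return false;  // equal shapes are also rejected after this change
  }
  return true;  // unreachable: shapes are either equal or not
}

int main() {
  assert(!canSupport(0, {2, 3}, {2, 3}));  // equal shapes: rejected
  assert(!canSupport(0, {2, 3}, {3}));     // unequal shapes: rejected
  return 0;
}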
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h}
@@ -16,12 +16,21 @@
 #ifndef ATTENTION_PARAMS_H
 #define ATTENTION_PARAMS_H
 
+#include "types.h"
+#include "op_param.h"
+
 namespace mindspore {
 namespace internal {
-struct FlashAttentionScoreParam {
+struct FlashAttentionScoreParam : public OpParam {
+  int head_num = 0;
+  int inner_precise = 0;
+  int pre_tokens = 2147483647;
+  int next_tokens = 0;
+  int sparse_mode = 0;
 };
 
-struct PagedAttentionParam {
+struct PagedAttentionParam : public OpParam {
+  int inner_precise = 0;
 };
 }  // namespace internal
 }  // namespace mindspore
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h
ADDED
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MATMUL_EXT_PARAMS_H_
+#define MATMUL_EXT_PARAMS_H_
+
+#include "types.h"
+#include "op_param.h"
+
+namespace mindspore {
+namespace internal {
+
+struct MatMulExtParam : public OpParam {
+  int input_dtype = -1;
+  int weight_dtype = -1;
+  int output_dtype = -1;
+  bool with_relu = false;
+  bool with_gelu = false;
+  bool with_bias = false;
+  bool with_bias_fastgelu = false;
+};
+
+}  // namespace internal
+}  // namespace mindspore
+#endif
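The flags read as a fused-epilogue selector for the extended matmul. A scalar reference sketch of the presumable per-element semantics (the ordering of the steps and the exact gelu formula are assumptions; the fastgelu variant is omitted):

#include <cmath>

// Illustrative epilogue for one matmul accumulator value `acc`,
// following the MatMulExtParam flags above.
float ApplyEpilogue(float acc, float bias, bool with_bias, bool with_relu, bool with_gelu) {
  if (with_bias) acc += bias;                                         // with_bias
  if (with_relu) acc = std::fmax(acc, 0.0f);                          // with_relu
  if (with_gelu) acc = 0.5f * acc * (1.0f + std::erf(acc / std::sqrt(2.0f)));  // with_gelu (exact erf form)
  return acc;
}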
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h
ADDED
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SUB_PARAMS_H_
+#define SUB_PARAMS_H_
+
+#include "types.h"
+#include "op_param.h"
+#include <set>
+
+namespace mindspore {
+namespace internal {
+struct SubParam : public OpParam {
+  TensorDType input1_dtype_;
+  TensorDType input2_dtype_;
+  DIMS input1_dims_;
+  DIMS input2_dims_;
+  bool canSupport() {
+    if (input2_dtype_ != AsdOps::TensorDType::TENSOR_DTYPE_INT32) {
+      return false;
+    }
+    if (input2_dims_.size() == 0 || (input2_dims_.size() == 1 && input2_dims_[0] == 1)) {
+      return true;
+    }
+    if (input1_dims_.size() == 0 || (input1_dims_.size() == 1 && input1_dims_[0] == 1)) {
+      return true;
+    }
+    return false;
+  }
+};
+}  // namespace internal
+}  // namespace mindspore
+#endif
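So the Sub fast path only accepts an int32 second operand where one side is scalar-like: rank 0, or rank 1 with a single element. A minimal sketch of the same predicate (SubCanSupport is an illustrative stand-in):

#include <cassert>
#include <cstdint>
#include <vector>

using DIMS = std::vector<int64_t>;

// Mirrors SubParam::canSupport(): int32 second operand, one side scalar-like.
bool SubCanSupport(bool input2_is_int32, const DIMS &a, const DIMS &b) {
  if (!input2_is_int32) return false;
  auto scalar_like = [](const DIMS &d) {
    return d.empty() || (d.size() == 1 && d[0] == 1);
  };
  return scalar_like(b) || scalar_like(a);
}

int main() {
  assert(SubCanSupport(true, {4, 8}, {}));       // tensor - scalar: supported
  assert(SubCanSupport(true, {1}, {4, 8}));      // scalar - tensor: supported
  assert(!SubCanSupport(true, {4, 8}, {4, 8}));  // elementwise: not this path
  assert(!SubCanSupport(false, {4, 8}, {}));     // non-int32 subtrahend: rejected
  return 0;
}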
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h
ADDED
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MS_KERNELS_INTERNAL_KERNEL_ASCENDC_RMS_NORM_H_
+#define MS_KERNELS_INTERNAL_KERNEL_ASCENDC_RMS_NORM_H_
+
+void rms_norm_do(uint32_t blockDim, void *l2ctrl, void *stream, uint8_t *x, uint8_t *gamma, uint8_t *y, uint8_t *rstd,
+                 uint8_t *workspace, uint8_t *tiling);
+
+#endif
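rms_norm_do is the host-side launcher: raw byte pointers for the input x, the gamma weights, the output y and the per-row reciprocal standard deviation rstd, plus workspace and tiling blobs. As a reference for what the kernel computes, a scalar RMSNorm sketch (standard definition; the epsilon value and where it lives, presumably in the tiling data, are assumptions):

#include <cmath>
#include <vector>

// Reference RMSNorm over one row of length d:
//   rstd = 1 / sqrt(mean(x_i^2) + eps),  y_i = x_i * rstd * gamma_i
void RmsNormRow(const std::vector<float> &x, const std::vector<float> &gamma,
                std::vector<float> &y, float &rstd, float eps = 1e-6f) {
  float sumsq = 0.0f;
  for (float v : x) sumsq += v * v;
  rstd = 1.0f / std::sqrt(sumsq / x.size() + eps);
  y.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] * rstd * gamma[i];
}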
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h
ADDED
@@ -0,0 +1,175 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*!
+ * \file rms_norm_base.h
+ * \brief
+ */
+#ifndef _RMS_NORM_BASE_H_
+#define _RMS_NORM_BASE_H_
+#include "kernel_operator.h"
+
+using namespace AscendC;
+
+#if __CCE_AICORE__ != 220
+#define bfloat16_t int16_t
+#endif
+constexpr int32_t BUFFER_NUM = 1;  // tensor num for each queue
+constexpr int32_t NUM_PER_REP_FP32 = 64;  // ONE_REPEAT_BYTE_SIZE / sizeof(float);
+constexpr int32_t NUM_PER_BLK_FP32 = 8;
+constexpr float MINUS_HALF = -0.5;
+constexpr float ZERO = 0;
+constexpr float ONE = 1;
+
+template <typename T>
+__aicore__ inline T CeilDiv(T x, T y) {
+  return y == 0 ? x : (x + y - 1) / y;
+}
+
+template <typename Tp, Tp v>
+struct integral_constant {
+  static constexpr Tp value = v;
+};
+using true_type = integral_constant<bool, true>;
+using false_type = integral_constant<bool, false>;
+template <typename, typename>
+struct is_same : public false_type {};
+template <typename Tp>
+struct is_same<Tp, Tp> : public true_type {};
+
+__aicore__ inline void ReduceSumFP32(const LocalTensor<float> &dst_local, const LocalTensor<float> &src_local,
+                                     const LocalTensor<float> &work_local, int32_t count) {
+  // count need smaller than 255 repeat
+  if (g_coreType == AIV) {
+    uint64_t mask = NUM_PER_REP_FP32;
+    int32_t repeatTimes = count / NUM_PER_REP_FP32;
+    int32_t tailCount = count % NUM_PER_REP_FP32;
+    int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
+    BinaryRepeatParams repeatParams;
+    repeatParams.src0RepStride = ONE_REPEAT_BYTE_SIZE / ONE_BLK_SIZE;
+    repeatParams.src0BlkStride = 1;
+    repeatParams.src1RepStride = 0;
+    repeatParams.src1BlkStride = 1;
+    repeatParams.dstRepStride = 0;
+    repeatParams.dstBlkStride = 1;
+    Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
+    pipe_barrier(PIPE_V);
+    if (likely(repeatTimes > 0)) {
+      Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
+      pipe_barrier(PIPE_V);
+    }
+    if (unlikely(tailCount != 0)) {
+      Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
+      pipe_barrier(PIPE_V);
+    }
+    AscendCUtils::SetMask<float>(NUM_PER_REP_FP32);
+    vcadd((__ubuf__ float *)dst_local.GetPhyAddr(), (__ubuf__ float *)work_local.GetPhyAddr(), 1, 0, 1, 0, false);
+    pipe_barrier(PIPE_V);
+  }
+}
+
+__aicore__ inline void ReduceSumCustom(const LocalTensor<float> &dst_local, const LocalTensor<float> &src_local,
+                                       const LocalTensor<float> &work_local, int32_t count) {
+#if __CCE_AICORE__ == 220
+  ReduceSumFP32(dst_local, src_local, work_local, count);
+#else
+  ReduceSum(dst_local, src_local, dst_local, count);
+#endif
+}
+
+__aicore__ inline void ReduceSumFP32ToBlock(const LocalTensor<float> &dst_local, const LocalTensor<float> &src_local,
+                                            const LocalTensor<float> &work_local, int32_t count) {
+  // count need smaller than 255 repeat
+  uint64_t mask = NUM_PER_REP_FP32;
+  int32_t repeatTimes = count / NUM_PER_REP_FP32;
+  int32_t tailCount = count % NUM_PER_REP_FP32;
+  int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
+  BinaryRepeatParams repeatParams;
+  repeatParams.src0RepStride = ONE_REPEAT_BYTE_SIZE / ONE_BLK_SIZE;
+  repeatParams.src0BlkStride = 1;
+  repeatParams.src1RepStride = 0;
+  repeatParams.src1BlkStride = 1;
+  repeatParams.dstRepStride = 0;
+  repeatParams.dstBlkStride = 1;
+  Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
+  pipe_barrier(PIPE_V);
+  if (likely(repeatTimes > 0)) {
+    Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
+    pipe_barrier(PIPE_V);
+  }
+  if (unlikely(tailCount != 0)) {
+    Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
+    pipe_barrier(PIPE_V);
+  }
+  BlockReduceSum(dst_local, work_local, 1, mask, 1, 1, DEFAULT_REPEAT_STRIDE);
+  pipe_barrier(PIPE_V);
+}
+
+__aicore__ inline void BlockReduceSumFP32(const LocalTensor<float> &dst_local, const LocalTensor<float> &src_local,
+                                          int32_t count) {
+  // count need multiple of 8
+  int32_t repeatTimes = count / NUM_PER_REP_FP32;
+  int32_t tailCount = count % NUM_PER_REP_FP32;
+  int32_t dstAddr = repeatTimes * 8;
+  int32_t srcAddr = repeatTimes * NUM_PER_REP_FP32;
+  if (likely(repeatTimes > 0)) {
+    BlockReduceSum(dst_local, src_local, repeatTimes, NUM_PER_REP_FP32, 1, 1, DEFAULT_REPEAT_STRIDE);
+    pipe_barrier(PIPE_V);
+  }
+  if (tailCount != 0) {
+    BlockReduceSum(dst_local[dstAddr], src_local[srcAddr], 1, tailCount, 1, 1, DEFAULT_REPEAT_STRIDE);
+    pipe_barrier(PIPE_V);
+  }
+}
+
+template <typename T, typename U, typename R>
+__aicore__ inline void DataCopyCustom(const U &dstTensor, const R &srcTensor, const uint32_t count) {
+#if __CCE_AICORE__ == 220
+  DataCopyParams copyParams;
+  copyParams.blockLen = count * sizeof(T);
+  copyParams.blockCount = 1;
+  if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
+    DataCopyPadParams padParams;
+    DataCopyPad(dstTensor, srcTensor, copyParams, padParams);
+  } else {
+    DataCopyPad(dstTensor, srcTensor, copyParams);
+  }
+#else
+  // only support count greater than 32byte
+  int32_t numPerBlock = ONE_BLK_SIZE / sizeof(T);
+  if (count % numPerBlock == 0) {
+    DataCopy(dstTensor, srcTensor, count);
+  } else {
+    if constexpr (is_same<U, AscendC::LocalTensor<T>>::value) {
+      int32_t num = AlignUp(count, numPerBlock);
+      DataCopy(dstTensor, srcTensor, num);
+    } else {
+      int32_t num = count / numPerBlock * numPerBlock;
+      DataCopy(dstTensor, srcTensor, num);
+      set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+      wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
+      for (int32_t i = 0; i < numPerBlock; i++) {
+        T tensorValue = srcTensor.GetValue(count - numPerBlock + i);
+        srcTensor.SetValue(i, tensorValue);
+      }
+      set_flag(PIPE_S, PIPE_MTE3, EVENT_ID0);
+      wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID0);
+      DataCopy(dstTensor[count - numPerBlock], srcTensor, numPerBlock);
+    }
+  }
+#endif
+}
+#endif  // RMS_NORM_BASE_H_
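The non-220 branch of DataCopyCustom above handles a count that is not a multiple of the 32-byte block by copying the aligned body first, then issuing one final aligned block that overlaps the body and ends exactly at count. The observable effect in plain host code, with memcpy standing in for DataCopy (assumes count is at least one block, as the header's comment requires):

#include <cassert>
#include <cstring>

// Copy `count` floats when transfers must be whole 8-element blocks:
// copy the aligned body, then one final block re-aligned to end at `count`.
void CopyWithOverlappingTail(float *dst, const float *src, int count) {
  const int kPerBlock = 8;                      // elements per aligned block
  int body = count / kPerBlock * kPerBlock;     // largest aligned prefix
  std::memcpy(dst, src, body * sizeof(float));  // body blocks
  if (body != count) {
    // Final block starts at count - kPerBlock, overlapping the body copy.
    std::memcpy(dst + count - kPerBlock, src + count - kPerBlock,
                kPerBlock * sizeof(float));
  }
}

int main() {
  float src[19], dst[19];
  for (int i = 0; i < 19; ++i) src[i] = float(i);
  CopyWithOverlappingTail(dst, src, 19);  // 19 = 2 full blocks + tail of 3
  for (int i = 0; i < 19; ++i) assert(dst[i] == src[i]);
  return 0;
}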