mindspore 2.3.0rc1__cp39-none-any.whl → 2.3.0rc2__cp39-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
- mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +20 -0
- mindspore/_extends/parse/parser.py +1 -1
- mindspore/_extends/parse/standard_method.py +6 -5
- mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +5 -5
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -2
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_stub_tensor.py +1 -0
- mindspore/common/api.py +56 -4
- mindspore/common/dtype.py +5 -3
- mindspore/common/dump.py +2 -2
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +17 -6
- mindspore/common/parameter.py +7 -2
- mindspore/common/recompute.py +247 -0
- mindspore/common/sparse_tensor.py +2 -2
- mindspore/common/symbol.py +1 -1
- mindspore/common/tensor.py +74 -36
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/management.py +30 -30
- mindspore/context.py +28 -15
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +51 -51
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +3 -3
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +3 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +3 -3
- mindspore/dataset/engine/datasets_vision.py +68 -68
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +26 -26
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/transforms.py +92 -92
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/experimental/optim/adadelta.py +2 -2
- mindspore/experimental/optim/adagrad.py +2 -2
- mindspore/experimental/optim/adam.py +2 -2
- mindspore/experimental/optim/adamax.py +2 -2
- mindspore/experimental/optim/adamw.py +2 -2
- mindspore/experimental/optim/asgd.py +2 -2
- mindspore/experimental/optim/lr_scheduler.py +24 -20
- mindspore/experimental/optim/nadam.py +2 -2
- mindspore/experimental/optim/optimizer.py +1 -1
- mindspore/experimental/optim/radam.py +2 -2
- mindspore/experimental/optim/rmsprop.py +2 -2
- mindspore/experimental/optim/rprop.py +2 -2
- mindspore/experimental/optim/sgd.py +2 -2
- mindspore/hal/stream.py +2 -0
- mindspore/include/mindapi/base/types.h +5 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +101787 -98559
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/base/op_register.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/mix.h +8 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/norm.h +5 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/params/reduce.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/backend.h +3 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/backend/rtbackend.h +3 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/module/module.h +3 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/svector/svector.h +3 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/tiling/add_tiling.h +9 -9
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +2 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_base.h +460 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_bf16.h +217 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h +116 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h +16 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h +27 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -4
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h} +2 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h} +15 -19
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h +7 -9
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h +58 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +19 -8
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h +18 -8
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h +7 -4
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h +44 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_utils.h +65 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +10 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +4 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +41 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/PagedAttention_impl.h → paged_attention/paged_attention_impl.h} +1 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +63 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/add_param.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention_param.h → param/attention_param.h} +11 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +37 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +45 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/reshape_and_cache/reshape_and_cache_tiling.h +1 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm.h +23 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_base.h +175 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_normal.h +276 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/rms_norm_split_d.h +280 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/kernel/tiling_data.h +35 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +45 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/kernel/sub_kernel.h +20 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +47 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_tiling.h +25 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +323 -23
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/types.h +15 -4
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +8 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal.h +22 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_comm.h +70 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcal_types.h +103 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl.h +47 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lccl_wrapper.h +58 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/include/lcoc.h +154 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +2 -2
- mindspore/mint/__init__.py +457 -0
- mindspore/mint/nn/__init__.py +430 -0
- mindspore/mint/nn/functional.py +424 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +186 -0
- mindspore/multiprocessing/__init__.py +4 -0
- mindspore/nn/__init__.py +3 -0
- mindspore/nn/cell.py +51 -47
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/nn/extend/layer/__init__.py +27 -0
- mindspore/nn/extend/layer/normalization.py +107 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/basic.py +109 -1
- mindspore/nn/layer/container.py +2 -2
- mindspore/nn/layer/conv.py +6 -6
- mindspore/nn/layer/embedding.py +1 -1
- mindspore/nn/layer/normalization.py +21 -43
- mindspore/nn/layer/padding.py +4 -0
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +7 -7
- mindspore/nn/optim/adamax.py +2 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +1 -1
- mindspore/nn/optim/lamb.py +3 -3
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +2 -2
- mindspore/nn/optim/momentum.py +2 -2
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +2 -2
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/sgd.py +2 -2
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +9 -9
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
- mindspore/ops/_vmap/vmap_math_ops.py +27 -8
- mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
- mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
- mindspore/ops/auto_generate/gen_extend_func.py +274 -0
- mindspore/ops/auto_generate/gen_ops_def.py +889 -22
- mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
- mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
- mindspore/ops/extend/__init__.py +9 -1
- mindspore/ops/extend/array_func.py +134 -27
- mindspore/ops/extend/math_func.py +3 -3
- mindspore/ops/extend/nn_func.py +363 -2
- mindspore/ops/function/__init__.py +19 -2
- mindspore/ops/function/array_func.py +463 -439
- mindspore/ops/function/clip_func.py +7 -18
- mindspore/ops/function/grad/grad_func.py +5 -5
- mindspore/ops/function/linalg_func.py +4 -4
- mindspore/ops/function/math_func.py +260 -243
- mindspore/ops/function/nn_func.py +825 -62
- mindspore/ops/function/random_func.py +73 -4
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +1 -1
- mindspore/ops/functional.py +2 -2
- mindspore/ops/op_info_register.py +1 -31
- mindspore/ops/operations/__init__.py +2 -3
- mindspore/ops/operations/_grad_ops.py +2 -107
- mindspore/ops/operations/_inner_ops.py +5 -5
- mindspore/ops/operations/_sequence_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +11 -233
- mindspore/ops/operations/comm_ops.py +32 -32
- mindspore/ops/operations/custom_ops.py +7 -89
- mindspore/ops/operations/manually_defined/ops_def.py +329 -4
- mindspore/ops/operations/math_ops.py +13 -163
- mindspore/ops/operations/nn_ops.py +9 -316
- mindspore/ops/operations/random_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +3 -3
- mindspore/ops/primitive.py +2 -2
- mindspore/ops_generate/arg_dtype_cast.py +12 -3
- mindspore/ops_generate/arg_handler.py +24 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
- mindspore/ops_generate/gen_pyboost_func.py +13 -6
- mindspore/ops_generate/pyboost_utils.py +2 -17
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +106 -1
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_utils.py +16 -0
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/checkpoint_transform.py +249 -77
- mindspore/parallel/cluster/process_entity/_api.py +1 -1
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +1 -1
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
- mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
- mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
- mindspore/profiler/parser/ascend_op_generator.py +26 -9
- mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
- mindspore/profiler/parser/profiler_info.py +11 -1
- mindspore/profiler/profiling.py +13 -5
- mindspore/rewrite/api/node.py +12 -12
- mindspore/rewrite/api/symbol_tree.py +11 -11
- mindspore/run_check/_check_version.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +2 -2
- mindspore/train/amp.py +4 -4
- mindspore/train/anf_ir_pb2.py +8 -2
- mindspore/train/callback/_backup_and_restore.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +2 -2
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +2 -2
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/dataset_helper.py +8 -3
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/mind_ir_pb2.py +22 -17
- mindspore/train/model.py +15 -15
- mindspore/train/serialization.py +18 -18
- mindspore/train/summary/summary_record.py +7 -7
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +307 -260
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/tiling_data.h +0 -59
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_bf16_BSH_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/FlashAttentionScore_fp16_BSH_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_bf16_BSH_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BNSD_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/PagedAttention_fp16_BSH_mix.o +0 -0
- /mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_mix_hwsync.h → flash_attention_score/kernel/flash_attention_score_mix_hwsync.h} +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp16.h
ADDED
@@ -0,0 +1,116 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef ROTARY_POS_EMB_FP16
+#define ROTARY_POS_EMB_FP16
+#include "apply_rotary_pos_emb_base.h"
+template <typename QK_DTYPE, typename COS_DTYPE, bool IF_COS_BROADCAST>
+class RopeFp16 : public RopeBase<QK_DTYPE, COS_DTYPE, IF_COS_BROADCAST> {
+ public:
+  __aicore__ inline RopeFp16(RopeTilingData *tilingData) : RopeBase<QK_DTYPE, COS_DTYPE, IF_COS_BROADCAST>(tilingData) {
+    this->repeatSize_ = 128;  // 128 = 256B / sizeof(half)
+    this->maxProcessNum_ = this->tilingData_->maxUbSize / sizeof(uint16_t);
+    this->repeatTimesQ_ = (this->tilingData_->hiddenSizeQ + this->repeatSize_ - 1) / this->repeatSize_;
+    this->repeatTimesK_ = (this->tilingData_->hiddenSizeK + this->repeatSize_ - 1) / this->repeatSize_;
+    headDimAlign_ = ((this->tilingData_->headDim + ELE_NUM_FP16 - 1) / ELE_NUM_FP16) * ELE_NUM_FP16;
+    this->alignHalfHeadDim_ = (this->rotateStride_ * NUM_TWO) % ELE_NUM_FP16;
+    this->hiddenSizeAlign_ = ((this->hiddenSize_ + this->repeatSize_ - 1) / this->repeatSize_) * this->repeatSize_;
+
+    this->cosPad_ = 0;
+    this->sinPad_ = this->cosPad_ + this->hiddenSizeAlign_;
+    this->negOne_ = this->sinPad_ + this->hiddenSizeAlign_;
+    this->oriPos_ = this->negOne_ + this->hiddenSizeAlign_;
+    this->padBefore_ = this->oriPos_ + this->hiddenSizeAlign_;
+    this->removeBefore_ = this->padBefore_ + this->hiddenSizeAlign_;
+    sinResPos_ = this->removeBefore_ + this->hiddenSizeAlign_;
+    this->repeatTimes_ = this->hiddenSizeAlign_ / this->repeatSize_;
+
+    this->syncOffset_ =
+      (this->tilingData_->headDim % ELE_NUM_FP16 == 0) ? this->hiddenSizeAlign_ : this->headNum_ * headDimAlign_;
+    this->offsetExtraGm_ = NUM_TWO * block_idx * this->syncOffset_;
+    this->pipe_.InitBuffer(outQueueCO2_, 1, ((this->maxProcessNum_ - this->batchSize_ * NUM_TWO) * sizeof(QK_DTYPE)));
+    AscendC::LocalTensor<QK_DTYPE> cache_perloop_ub_ = outQueueCO2_.AllocTensor<QK_DTYPE>();
+    commonUbuf_ = (__ubuf__ QK_DTYPE *)cache_perloop_ub_.GetPhyAddr();
+  }
+
+  __aicore__ inline void Process(__gm__ uint8_t *extraGm) {
+    if (this->tilingData_->cosFormat == 1) {
+      pipe_barrier((PIPE_ALL));
+      this->ExpandCosSin(commonUbuf_, this->cosGm_, (__gm__ COS_DTYPE *)extraGm);
+      this->cosGm_ = (__gm__ COS_DTYPE *)extraGm;
+      pipe_barrier((PIPE_ALL));
+      this->ExpandCosSin(commonUbuf_, this->sinGm_,
+                         (__gm__ COS_DTYPE *)extraGm + this->tilingData_->ntokens * this->tilingData_->headDim);
+      this->sinGm_ = (__gm__ COS_DTYPE *)extraGm + this->tilingData_->ntokens * this->tilingData_->headDim;
+      extraGm =
+        extraGm + this->tilingData_->ntokens * this->tilingData_->headDim * 4;  // sizeof(uint8_t) * 2 = sizeof(half)
+      pipe_barrier((PIPE_ALL));
+    }
+
+    this->ExpandNeg(commonUbuf_, sinResPos_, this->headNum_, this->repeatTimes_);  // choose 1 -1 or -1 0 depending on alignment
+    for (uint32_t zz = 0; zz < this->dynamicRound_; ++zz) {
+      this->CosSinBroadcast(extraGm, zz, commonUbuf_, this->tilingData_->hiddenSizeQ);  // cos/sin are independent of Q and K
+
+      this->QkComm(this->qGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeQ +
+                     zz * this->tilingData_->hiddenSizeQ,
+                   extraGm, this->tilingData_->hiddenSizeQ, commonUbuf_, this->tilingData_->headNumQ);
+
+      if (this->alignRotary_ == 0) {
+        pipe_barrier((PIPE_V));
+        this->CalcRopeAlign(commonUbuf_, this->repeatTimesQ_, this->oriPos_, this->removeBefore_, this->padBefore_);
+      } else {
+        set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+        wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+        this->CalcRope(commonUbuf_, this->repeatTimesQ_, this->oriPos_, this->removeBefore_, this->padBefore_,
+                       sinResPos_, this->padBefore_);
+      }
+      pipe_barrier((PIPE_ALL));  // required
+      copy_ubuf_to_gm(this->outQGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeQ +
+                        zz * this->tilingData_->hiddenSizeQ,
+                      commonUbuf_ + this->padBefore_, 0, 1, this->tilingData_->hiddenSizeQ / ELE_NUM_FP16, 0, 0);
+
+      set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1);
+      wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1);
+
+      this->QkComm(this->kGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeK +
+                     zz * this->tilingData_->hiddenSizeK,
+                   extraGm, this->tilingData_->hiddenSizeK, commonUbuf_, this->tilingData_->headNumK);
+
+      if (this->alignRotary_ == 0) {
+        pipe_barrier((PIPE_V));
+        this->CalcRopeAlign(commonUbuf_, this->repeatTimesK_, this->oriPos_, this->removeBefore_, this->padBefore_);
+      } else {
+        set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+        wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
+        this->CalcRope(commonUbuf_, this->repeatTimesK_, this->oriPos_, this->removeBefore_, this->padBefore_,
+                       sinResPos_, this->padBefore_);
+      }
+      pipe_barrier((PIPE_ALL));  // required
+      copy_ubuf_to_gm(this->outKGm_ + block_idx * this->nlCoreRun_ * this->tilingData_->hiddenSizeK +
+                        zz * this->tilingData_->hiddenSizeK,
+                      commonUbuf_ + this->padBefore_, 0, 1, this->tilingData_->hiddenSizeK / ELE_NUM_FP16, 0, 0);
+      set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1);
+      wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1);
+    }
+  }
+
+ private:
+  AscendC::TQue<AscendC::QuePosition::VECIN, 1> outQueueCO2_;
+  __ubuf__ QK_DTYPE *commonUbuf_{nullptr};
+  uint32_t headDimAlign_;   // aligned headDim
+  uint32_t sinResPos_{0};   // position of the 0 0 0 1 1 1 pattern in the fp32 buffer
+};
+
+#endif
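Note: the constructor above repeatedly rounds sizes up to a hardware multiple (headDimAlign_, hiddenSizeAlign_). A minimal standalone sketch of that integer align-up idiom, not taken from the package:

#include <cstdint>
#include <cassert>

// Round x up to the next multiple of n, integer arithmetic only.
// Valid for n > 0 with x + n - 1 not overflowing uint32_t.
static inline uint32_t AlignUp(uint32_t x, uint32_t n) {
  return ((x + n - 1) / n) * n;  // ceil(x / n) * n
}

int main() {
  assert(AlignUp(96, 16) == 96);     // already a multiple: unchanged
  assert(AlignUp(100, 16) == 112);   // ELE_NUM_FP16-style 16-element alignment
  assert(AlignUp(130, 128) == 256);  // repeatSize_-style 128-element alignment
  return 0;
}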
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_tiling.h
CHANGED
@@ -14,38 +14,30 @@
  * limitations under the License.
  */
 
-#ifndef …
-#define …
+#ifndef MS_KERNELS_INTERNAL_KERNEL_ASCENDC_ROPE_TILING_DATA_H_
+#define MS_KERNELS_INTERNAL_KERNEL_ASCENDC_ROPE_TILING_DATA_H_
 
 #include <stdint.h>
 
-struct …
+struct RopeTilingData {
   uint32_t hiddenSizeQ{16};
   uint32_t hiddenSizeK{16};
-  uint32_t headDim{1};
+  uint32_t headDim{1};  // maximum head dim of Q/K
   uint32_t headNumQ{1};
   uint32_t headNumK{1};
-  uint32_t rotaryCoeff{4};
-  uint32_t ntokens{1};
-  uint32_t …
-  uint32_t …
-  uint32_t …
-  uint32_t …
-  uint32_t …
-  uint32_t maxUbSize{0};
+  uint32_t rotaryCoeff{4};  // rotary coefficient
+  uint32_t ntokens{1};      // total number of tokens
+  uint32_t realCore{0};     // number of cores actually used
+  uint32_t cosFormat{0};    // whether cos/sin are reused
+  uint32_t batch{32};       // number of batches
+  uint32_t maxUbSize{0};    // maximum UB memory
+  uint32_t tilingId{0};
 
-
-
-
-
-
-  int32_t numHeadQ;
-  int32_t numHeadK;
-  // int32_t hiddenDim;
-  int32_t seqLen;
-  int32_t maxSeqLen;
-
-  int32_t posSize;  // seqLen==1 ? batch : seqLen
+  uint32_t seqLen;
+  uint32_t broadCastCos{0};
+  uint32_t posDtype;
+  uint32_t posSize;
+  uint32_t maxSeqLen;
 };
 
 #endif
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_value.h
ADDED
@@ -0,0 +1,27 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef COMMON_VAL_H
+#define COMMON_VAL_H
+const constexpr uint32_t NUM_TWO = 2;               // 2
+const constexpr uint32_t BLK_SIZE = 32;             // bytes per block
+const constexpr uint32_t ELE_NUM_FP16 = 16;         // fp16 elements per block
+const constexpr uint32_t ELE_NUM_FP32 = 8;          // fp32 elements per block
+const constexpr uint32_t MAX_LEN_FP16 = 8192;       // maximum length (hiddenSize) in the non-fp16 case
+const constexpr uint8_t DEFAULT_REPEAT_STRIDE = 8;  // default stride, 8 * 32 = 256
+const constexpr int64_t REG_910B = 48;              // saturation-mode register position
+const constexpr int64_t REG_310P = 53;              // saturation-mode register position
+const constexpr int64_t SLICE_SIZE = 4096;          // slice size
+#endif
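Note: the block constants above are mutually consistent. A standalone compile-time check, not part of the package, with uint16_t standing in for the device-side half type:

#include <cstdint>

constexpr uint32_t BLK_SIZE = 32;
constexpr uint32_t ELE_NUM_FP16 = 16;
constexpr uint32_t ELE_NUM_FP32 = 8;
constexpr uint32_t DEFAULT_REPEAT_STRIDE = 8;

static_assert(BLK_SIZE / sizeof(uint16_t) == ELE_NUM_FP16, "16 fp16 elements fit in one 32B block");
static_assert(BLK_SIZE / sizeof(float) == ELE_NUM_FP32, "8 fp32 elements fit in one 32B block");
static_assert(DEFAULT_REPEAT_STRIDE * BLK_SIZE == 256, "default stride spans one 256B repeat");

int main() { return 0; }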
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h
CHANGED
@@ -41,8 +41,6 @@ class AsdOpsImpl : public InternelKernelImpl {
   std::vector<uint64_t> GetWorkSpaceSize() override;
   int InferShape(const std::vector<DIMS> &input_shapes, std::vector<DIMS> &output_shapes) override;
 
-  void SetCacheInfo(const CacheInfo &cache_info);
-
  private:
   AsdOps::Tactic *InitAndGetTactic();
 
@@ -52,8 +50,6 @@ class AsdOpsImpl : public InternelKernelImpl {
   AsdOps::LaunchParam launch_param_;
   AsdOps::OpDesc op_desc_;
   bool validated_ = false;
-
-  RunInfo run_info_;
 };
 
 } // namespace internal
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/FlashAttentionScore_impl.h → flash_attention_score/flash_attention_score_impl.h}
CHANGED
@@ -24,7 +24,7 @@
 #include "asdops/tensor.h"
 
 #include "internal_kernel.h"
-
+#include "param/attention_param.h"
 #include "acl_rt.h"
 
 #include <unordered_map>
@@ -49,6 +49,7 @@ class FlashAttentionScoreImpl : public InternelKernelImpl {
 
  private:
   uint64_t B, N, Q_S, KV_S, D, G, CORE_NUM;
+  int inner_precise, pre_tokens, next_tokens, sparse_mode;
   bool BFLOAT16, BSH, ALIBI, AMASK;
   const std::vector<Tensor *> *inputs_;
   const std::vector<Tensor *> *outputs_;
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{attention/bs_attention_tiling.h → flash_attention_score/flash_attention_score_tiling.h}
CHANGED
@@ -18,16 +18,12 @@ typedef struct {
 #pragma pack()
 
 #define MAX_CORE_NUM 25
-#define ATTENTION_DEBUG false
-#define …
-
-#if PA
-#define INC true
-#define OP_NAME PagedAttention
-#else
-#define INC false
+#define ATTENTION_DEBUG false  // when enabled, debug data is written to S/P
+#define ROWMAX true
 #define OP_NAME FlashAttentionScore
-#…
+#define BUFFER_NUM 2  // inter-core pipeline depth; modifying it is not yet supported
+constexpr uint64_t WORKSPACE_MAX_SEQLEN = 16384;  // max seqlen
+constexpr uint64_t WORKSPACE_SIZE = 128 * WORKSPACE_MAX_SEQLEN;
 
 #if BFLOAT16
 #define TYPE_NAME _bf16
@@ -41,14 +37,16 @@ typedef struct {
 #define LAYOUT_NAME _BNSD
 #endif
 
-#…
-#define …
-#…
-#define …
-
-#if INC
-#define CORE_PER_KV_HEAD 4  // enabled during incremental inference: how many tasks each kv_head is split into
+#if LOWER_TRIANGLE
+#define TRI_NAME _tri
+#else
+#define TRI_NAME _full
 #endif
+
+#define CONCAT_(A, B, C, D, E) A##B##C##D##E
+#define CONCAT(A, B, C, D, E) CONCAT_(A, B, C, D, E)
+#define FUNC_NAME_AIC CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aic)
+#define FUNC_NAME_AIV CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aiv)
 
 // ************** mask pattern modes **************//
 // Mode 1: lower triangle; when LOWER_TRIANGLE is enabled the lower triangle is applied directly, without depending on the mask
@@ -63,9 +61,7 @@
 // Mode 4: full matrix; if LOWER_TRIANGLE, BLOCK_SPARSE and AMASK are all disabled, attention computes over the full matrix and no elements of S are suppressed
 // *******************************************//
 
-constexpr uint64_t …
-constexpr uint64_t WORKSPACE_MAX_SEQLEN_BLOCK = WORKSPACE_MAX_SEQLEN / 16;  // max seqlen: 10240
-constexpr uint64_t WORKSPACE_SIZE = 128 * WORKSPACE_MAX_SEQLEN;
+constexpr uint64_t WORKSPACE_MAX_SEQLEN_BLOCK = WORKSPACE_MAX_SEQLEN / 16;
 constexpr uint64_t BUFFER_SIZE = MAX_CORE_NUM * WORKSPACE_SIZE * sizeof(uint16_t);
 
 #endif
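Note: the new FUNC_NAME_AIC/AIV kernel names are assembled with a two-level concatenation macro. A standalone sketch, not from the package, of why the indirection is needed: the extra level forces OP_NAME and friends to be macro-expanded before ## pastes them, matching the renamed kernel objects (e.g. flash_attention_score_bf16_bsh_tri_mix.o) in the file list above.

#include <cstdio>

#define CONCAT_(A, B, C, D, E) A##B##C##D##E
#define CONCAT(A, B, C, D, E) CONCAT_(A, B, C, D, E)  // expand args first, then paste

#define OP_NAME FlashAttentionScore
#define TYPE_NAME _bf16
#define LAYOUT_NAME _BSH
#define TRI_NAME _tri

// Expands to the single identifier FlashAttentionScore_bf16_BSH_tri_mix_aic.
// A one-level A##B##C##D##E would paste the literal tokens OP_NAME, TYPE_NAME, ... instead.
void CONCAT(OP_NAME, TYPE_NAME, LAYOUT_NAME, TRI_NAME, _mix_aic)() {
  std::puts("FlashAttentionScore_bf16_BSH_tri_mix_aic");
}

int main() {
  FlashAttentionScore_bf16_BSH_tri_mix_aic();  // the pasted name is a normal function name
  return 0;
}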
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/gelu/tiling/gelu_tiling.h
CHANGED
@@ -26,15 +26,13 @@ struct GeLUTilingData {
   uint32_t tailBlockTileNum{0};
 };
 static std::ostream &operator<<(std::ostream &os, const GeLUTilingData &dt) {
-  os << "blockDims:" << dt.blockDims
-  os << "totalLength:" << dt.totalLength
-  os << "blockLength:" << dt.blockLength
-  os << "tileLength:" << dt.tileLength
-  os << "tileNum:" << dt.tileNum
-  os << "tailBlockTileNum:" << dt.tailBlockTileNum
-  os << "tilingKey:" << dt.tilingKey
-  // os << "axisDim:" << dt.axisDim << std::endl;
-  // os << "splitNum:" << dt.splitNum << std::endl;
+  os << "blockDims:" << dt.blockDims;
+  os << ", totalLength:" << dt.totalLength;
+  os << ", blockLength:" << dt.blockLength;
+  os << ", tileLength:" << dt.tileLength;
+  os << ", tileNum:" << dt.tileNum;
+  os << ", tailBlockTileNum:" << dt.tailBlockTileNum;
+  os << ", tilingKey:" << dt.tilingKey;
   return os;
 }
 #endif  // MS_KERNELS_INTERNAL_ASCENDC_GELU_TILING_H
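Note: the old operator<< body was missing its trailing semicolons, so it could not compile; the fixed version prints the fields as one comma-separated line. A usage sketch, not part of the package (the include path is assumed):

#include <iostream>
#include "gelu/tiling/gelu_tiling.h"  // assumed include path

int main() {
  GeLUTilingData dt;          // remaining fields keep their {0} defaults
  dt.blockDims = 8;
  dt.totalLength = 4096;
  std::cout << dt << '\n';    // e.g. blockDims:8, totalLength:4096, blockLength:0, ...
  return 0;
}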
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/lccl/lccl_wrapper.h
ADDED
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LCCL_WRAPPER_H_
+#define LCCL_WRAPPER_H_
+
+#include <memory>
+#include "lccl.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+using namespace Lcal;
+using LcclComm = std::shared_ptr<Lccl>;
+enum class LcclResult {
+  LCAL_SUCCESS = 0,
+  LCAL_ERROR_NOT_INITIALIZED = -1,
+  LCAL_ERROR_ASDRT = -2,
+  LCAL_ERROR_PARA_CHECK_FAIL = -3,
+  LCAL_ERROR_INTERNAL = -4,
+  LCAL_ERROR_TIMEOUT = -5,
+  LCCL_ERROR_INIT_HCCL_FAILED = -6
+};
+
+extern LcclResult LcclCommInitRank(uint32_t nRanks, uint32_t rank, LcclComm *comm);
+
+extern LcclResult LcclAllReduce(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType,
+                                HcclReduceOp op, aclrtStream stream);
+
+extern LcclResult LcclReduceScatter(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType,
+                                    HcclReduceOp op, aclrtStream stream);
+
+extern LcclResult LcclAllGather(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream);
+
+extern LcclResult LcclAll2All(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream);
+
+extern LcclResult LcclBroadcast(void *buff, int64_t count, HcclDataType dataType, int32_t root, aclrtStream stream);
+
+extern LcclResult LcclCommDestroy(LcclComm comm);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // LCCL_WRAPPER_H_
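Note: a hypothetical caller of the wrapper API above, not shipped in the package. It uses only the functions declared in the header; HCCL_DATA_TYPE_FP32 / HCCL_REDUCE_SUM come from the CANN HCCL headers and the aclrt* stream calls from the ACL runtime, whose exact headers may differ by CANN version.

#include <cstdint>
#include "lccl_wrapper.h"  // assumed include path

int RunAllReduce(uint32_t nRanks, uint32_t rank, void *sendBuf, void *recvBuf, int64_t count) {
  LcclComm comm;
  if (LcclCommInitRank(nRanks, rank, &comm) != LcclResult::LCAL_SUCCESS) {
    return -1;  // communicator could not be created
  }
  aclrtStream stream = nullptr;
  aclrtCreateStream(&stream);  // ACL runtime call, assumed available via lccl.h
  LcclResult ret = LcclAllReduce(sendBuf, recvBuf, count, HCCL_DATA_TYPE_FP32, HCCL_REDUCE_SUM, stream);
  aclrtSynchronizeStream(stream);  // wait for the collective to finish
  aclrtDestroyStream(stream);
  LcclCommDestroy(comm);
  return ret == LcclResult::LCAL_SUCCESS ? 0 : -1;
}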
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h
CHANGED
@@ -24,11 +24,13 @@
 #include "asdops/tensor.h"
 
 #include "utils.h"
-#include "…
-#include "…
-#include "…
+#include "backend_param.h"
+#include "param/matmul_ext_param.h"
+#include "matmul_common/pp_matmul_info.h"
+#include "matmul_common/tiling_utils.h"
+#include "matmul_common/tiling_data.h"
+#include "matmul_common/pp_matmul_common_tiling.h"
 #include "tune_repo/utils.h"
-
 #include "internal_kernel.h"
 
 #include "acl_rt.h"
@@ -37,6 +39,8 @@
 namespace mindspore {
 namespace internal {
 
+using namespace tiling;
+
 enum class MatMulAlgo { PP = 0, LLM_CUSTOM = 1 };
 
 class MatMulImpl : public InternelKernelImpl {
@@ -48,12 +52,15 @@ class MatMulImpl : public InternelKernelImpl {
   int Launch() override;
   size_t GetTilingBufSize() override;
   int Tiling(HostRawBuf &tilingBuf) override;
-
-  int …
+  void TilingBasicFromPp(uint32_t &blockDim, PpTilingData &tilingdata);
+  int TilingPp(HostRawBuf &tilingBuf, uint64_t tilingId, const uint32_t &blockDim, const PpTilingData &tilingdata);
+  int TilingLLMCustom(HostRawBuf &tilingBuf, uint64_t tilingId, const uint32_t &blockDim,
+                      const PpTilingData &tilingdata);
   std::vector<uint64_t> GetWorkSpaceSize() override;
   int InferShape(const std::vector<DIMS> &input_shapes, std::vector<DIMS> &output_shapes) override;
-  bool …
-
+  bool UseCustomMatMul();
+  void GetTunedKey();
+  void SetTunedValueCustom(const std::vector<int> &tuned_config);
 
  private:
   uint32_t m_, k_, n_;
@@ -61,7 +68,11 @@ class MatMulImpl : public InternelKernelImpl {
   MatMulAlgo algo_ = MatMulAlgo::PP;
   DeviceRawBuf tiling_addr_;
   std::string soc_{"Ascend910B2"};
+  HardwareInfo hwInfo_;
+  CustomMatmulTilingData t_;
+  std::vector<int> tune_key_;
   REPO tuningTable_;
+  REPO tuningTableCustom_;
   TensorDType input_dtype_;
   TensorDType output_dtype_;
   int block_dim_ = 0;
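Note: a guess at how the split tiling entry point could wire the new helpers together; the package's actual .cc implementation is not in this diff. The sketch only calls methods declared in the header above, and the tilingId value is a placeholder.

// Plausible dispatch shape, assuming Tiling() delegates to the new helpers.
int MatMulImpl::Tiling(HostRawBuf &tilingBuf) {
  uint32_t blockDim = 0;
  PpTilingData tilingdata;
  TilingBasicFromPp(blockDim, tilingdata);  // shared basic PP tiling
  const uint64_t tilingId = 0;              // placeholder, real id selection unknown
  return UseCustomMatMul()
           ? TilingLLMCustom(tilingBuf, tilingId, blockDim, tilingdata)
           : TilingPp(tilingBuf, tilingId, blockDim, tilingdata);
}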
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_common_tiling.h
CHANGED
@@ -17,6 +17,8 @@
 #ifndef MATMUL_COMMMON_TILING_H
 #define MATMUL_COMMMON_TILING_H
 
+#include <cmath>
+#include <iostream>
 #include "pp_matmul_info.h"
 
 namespace mindspore {
@@ -100,9 +102,12 @@ inline __attribute__((always_inline)) float CostFunc(const HardwareType &hwInfo,
 template <bool PRI_FLAG, typename OpShareType, typename TilingType, typename HardwareType, typename MatMulInfoType>
 void TilingFunc(OpShareType &opShape, TilingType &tilingParam, const HardwareType &hwInfo, const MatMulInfoType &mmInfo,
                 bool compressFlag = false, const uint32_t tilingN = 1) {
+  using namespace std;
   float costMin = 1;
-  uint32_t …
-  uint32_t …
+  const uint32_t CONST_16 = 16;
+  uint32_t roundBase = pow(2, ceil(log(CeilDiv(PRI_FLAG ? opShape.n : opShape.m, CONST_16)))) * CONST_16;
+  uint32_t priAxes = RoundUp(PRI_FLAG ? opShape.m : opShape.n, CONST_16);
+  uint32_t axes = RoundUp(PRI_FLAG ? opShape.n : opShape.m, roundBase);
   float axes0Max = static_cast<float>(AXES_ALIGN_SIZE) / mmInfo.inDtype;
 
   uint32_t n0TilingInit =
@@ -129,10 +134,15 @@ void TilingFunc(OpShareType &opShape, TilingType &tilingParam, const HardwareType &hwInfo,
   }
   opShape.m0 = PRI_FLAG ? priAxes0 : axes0;
   opShape.n0 = PRI_FLAG ? axes0 : priAxes0;
+  if ((mmInfo.qkv_n0 + mmInfo.qkv_n1 + mmInfo.qkv_n2 != 0) &&
+      (mmInfo.qkv_n0 < opShape.n0 || mmInfo.qkv_n1 < opShape.n0 ||
+       (mmInfo.qkv_n2 < opShape.n0 && mmInfo.qkv_n2 > 1))) {
+    continue;
+  }
   float cost = CostFunc<HardwareType, OpShareType>(hwInfo, opShape);
   if (cost < costMin) {
     costMin = cost;
-    tilingParam.SetBaseOp(hwInfo.coreNum, opShape.m0, opShape.n0);
+    tilingParam.SetBaseOp(hwInfo.coreNum, opShape.m0, opShape.n0, mmInfo.qkv_n0, mmInfo.qkv_n1, mmInfo.qkv_n2);
   }
 }
@@ -140,7 +150,7 @@ void TilingFunc(OpShareType &opShape, TilingType &tilingParam, const HardwareType &hwInfo,
 
 template <typename PpTilingDataType>
 uint32_t Swizzl(PpTilingDataType &tilingData) {
-  uint32_t …
+  uint32_t swizzlDirect = 0;
   uint32_t swizzlCount = 1;
   float m0 = tilingData.opShape.m0;
   float n0 = tilingData.opShape.n0;
@@ -154,14 +164,14 @@ uint32_t Swizzl(PpTilingDataType &tilingData) {
     float cost;
     // B0 + A < A0 + B
    if (i * n0 + m < m0 * c + n) {
-      …
+      swizzlDirect = 1;  // Nz
       cost = n0 * i + m0 * c;
       if (cost <= mincost) {
         mincost = cost;
         swizzlCount = i;
       }
     } else {
-      …
+      swizzlDirect = 0;  // Zn
       cost = m0 * i + n0 * c;
       if (cost < mincost) {
         mincost = cost;
@@ -169,9 +179,9 @@ uint32_t Swizzl(PpTilingDataType &tilingData) {
       }
     }
   }
-  tilingData.…
+  tilingData.swizzlDirect = swizzlDirect;
   tilingData.swizzlCount = swizzlCount;
-  return …
+  return swizzlDirect;
 }
 
 } // namespace tiling
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/pp_matmul_info.h
CHANGED
@@ -26,7 +26,10 @@ namespace internal {
 namespace tiling {
 struct MatMulInfo {
   uint32_t batchSize{0};
-  uint32_t m{0};
+  uint32_t m{0};  // actual input m
+  uint32_t qkv_n0{0};
+  uint32_t qkv_n1{0};
+  uint32_t qkv_n2{0};
   uint32_t n{0};  // actual input n
   uint32_t k{0};  // actual input k
   bool transA{0};  // false: 0, true: 1
@@ -57,12 +60,12 @@ struct PpTilingData {
   uint32_t swizzlCount{1};
   uint32_t tilingKey{0};
   uint32_t blockDim{1};
-  uint32_t …
+  uint32_t swizzlDirect{0};
   uint32_t splitk{0};
 
   void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n);
-  void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase);
-  void SetTilingKey(const MatMulInfo &mmInfo, uint32_t …
+  void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, uint32_t qkv_n0, uint32_t qkv_n1, uint32_t qkv_n2);
+  void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzlDirect, uint32_t enSplitK);
   uint32_t End(const MatMulInfo &mmInfo);
 };
 } // namespace tiling
mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/{matmul → matmul_common}/tiling_data.h
CHANGED
@@ -18,7 +18,6 @@
 #define MATMUL_TILING_DATA_H
 
 #include <stdint.h>
-#include <algorithm>
 
 namespace mindspore {
 namespace internal {
@@ -38,8 +37,16 @@ struct PpMatmulTilingData {
   uint32_t swizzlCount{0};
   uint32_t tilingKey{0};
   uint32_t blockDim{1};
-  uint32_t …
+  uint32_t swizzlDirect{0};
   uint32_t splitk{0};
+  uint32_t enShuffleK{0};
+  uint32_t unused0{0};
+  uint32_t unused1{0};
+  uint32_t unused2{0};
+  uint32_t unused3{0};
+  uint32_t unused4{0};
+  uint32_t unused5{0};
+  uint32_t unused6{0};
   uint32_t tilingId{0};
 };
@@ -60,18 +67,49 @@ struct CustomMatmulTilingData {
   uint32_t BaseKNum{0};
   uint32_t BaseNNum{0};
   uint32_t MmadM{0};
+  uint32_t MmadK{0};
+  uint32_t MmadN{0};
+  uint32_t fractal_k_num{0};
+  uint32_t FractalKInBlockNum{0};
+  uint32_t PartKInMmad{0};
+  uint32_t TransA{0};
+  uint32_t TransB{0};
+  uint32_t shuffleFlag{0};
   uint32_t tilingId{0};
+};
+
+constexpr size_t maxTilingBufSize = sizeof(CustomMatmulTilingData);
+
+struct MatmulStridedSliceFusionTilingData {
+  uint32_t tilingId{0};
+  uint32_t BlockDimM{0};
+  uint32_t BlockDimN{0};
+  uint32_t BlockTotal{0};
+  uint32_t M{0};
+  uint32_t K{0};
+  uint32_t N{0};
+  uint32_t N0{0};
+  uint32_t N1{0};
+  uint32_t N2{0};
+  uint32_t BaseM{0};
+  uint32_t BaseK{0};
+  uint32_t BaseN{0};
+  uint32_t BlockLenM{0};
+  uint32_t BlockLenK{0};
+  uint32_t BlockLenN{0};
+  uint32_t BaseMNum{0};
+  uint32_t BaseKNum{0};
+  uint32_t BaseNNum{0};
+  uint32_t MmadM{0};
   uint32_t MmadK{0};
   uint32_t MmadN{0};
-  uint32_t FractalKNum{0};
   uint32_t FractalKInBlockNum{0};
-  uint32_t …
+  uint32_t PartKInMmad{2};
   uint32_t TransA{0};
   uint32_t TransB{1};
+  uint32_t shuffleFlag{0};
 };
 
-constexpr size_t maxTilingBufSize = std::max(sizeof(PpMatmulTilingData), sizeof(CustomMatmulTilingData));
-
 } // namespace tiling
 } // namespace internal
 } // namespace mindspore