@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -55,6 +55,7 @@
 
 #include <atomic>
 #include <array>
+#include <type_traits>
 
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
@@ -1053,6 +1054,493 @@ class tinyBLAS_Q0_AVX {
       } \
    } \
 
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_BF16_PPC {
+  public:
+    tinyBLAS_BF16_PPC(int64_t k,
+                      const TA *A, int64_t lda,
+                      const TB *B, int64_t ldb,
+                      TC *C, int64_t ldc,
+                      int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
+        vec_t t[8], s[8];
+        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+        if (numVec == 2) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[2], c[3], swiz1);
+            s[0] = vec_perm(t[0], t[1], swiz3);
+            s[1] = vec_perm(t[0], t[1], swiz4);
+            vec_xst(s[0], 0, (vec_t*)vecOffset);
+            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
+        } else if (numVec == 4) {
+            t[0] = vec_perm(c[0], c[1], swiz1);
+            t[1] = vec_perm(c[0], c[1], swiz2);
+            t[2] = vec_perm(c[2], c[3], swiz1);
+            t[3] = vec_perm(c[2], c[3], swiz2);
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            for (int i = 0; i < 4; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        } else if (numVec == 8) {
+            for (int i = 0; i < 4; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            for (int i = 4; i < 8; i += 2) {
+                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
+                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
+            }
+            s[0] = vec_perm(t[0], t[2], swiz3);
+            s[1] = vec_perm(t[0], t[2], swiz4);
+            s[2] = vec_perm(t[1], t[3], swiz3);
+            s[3] = vec_perm(t[1], t[3], swiz4);
+            s[4] = vec_perm(t[4], t[6], swiz3);
+            s[5] = vec_perm(t[4], t[6], swiz4);
+            s[6] = vec_perm(t[5], t[7], swiz3);
+            s[7] = vec_perm(t[5], t[7], swiz4);
+            for (int i = 0; i < 8; ++i)
+                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
+        }
+    }
+
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        unsigned char *vecOffset = NULL;
+        TA * aoffsets[8];
+        vector unsigned char c_arr[8];
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                if (cols == 4) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 4; ++it)
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    aoffset += 4 * lda;
+                    for (int i = 0; i < 4; ++i)
+                        c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int i = 0; i<4; i++)
+                        aoffsets[i] = aoffsets[i]+lda;
+                    vecOffset +=64;
+                }
+                i = (cols >> 3);
+                if (i > 0) {
+                    aoffsets[0] = aoffset;
+                    for (int it = 1; it < 8; ++it) {
+                        aoffsets[it] = aoffsets[it-1] + lda;
+                    }
+                    aoffset += 8 * lda;
+                    do {
+                        for (int it = 0; it < 8; ++it)
+                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                        vector_permute_store(c_arr, 8, vecOffset);
+                        for (int it = 0; it < 8; ++it)
+                            aoffsets[it] = aoffsets[it] + 8*lda;
+                        vecOffset += 128;
+                        i--;
+                    } while(i > 0);
+                }
+                j--;
+            } while(j > 0);
+        }
+        if (rows & 4) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            aoffset += 4 * lda;
+            if (cols == 4) {
+                for (int it = 0; it < 4; ++it)
+                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it< 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    for (int it = 0; it < 4; ++it)
+                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it< 4; it++)
+                        aoffsets[it] = aoffsets[it] + 8*lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+        if (rows & 3) {
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; ++it)
+                aoffsets[it] = aoffsets[it-1] + lda;
+            if (cols == 4) {
+                switch(rows) {
+                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                        break;
+                }
+                vector_permute_store(c_arr, 2, vecOffset);
+                for (int it = 0; it< 4; it++)
+                    aoffsets[it] = aoffsets[it] + lda;
+                vecOffset += 32;
+            }
+            i = (cols >> 3);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
+                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
+                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
+                            break;
+                    }
+                    vector_permute_store(c_arr, 4, vecOffset);
+                    for (int it = 0; it <4; it++)
+                        aoffsets[it] = aoffsets[it] + 8* lda;
+                    vecOffset += 64;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >=8 && n_rem >=4){
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem >= 8)) {
+            nc = 8;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_Mx8<1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_Mx8<2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_Mx8<3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2,4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[4], vec_B[8] , vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l+=8) {
+            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[4] , vec_C[4];
+        acc_t acc_0, acc_1;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        for (int l = 0; l < k; l+=8) {
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii+4, jj);
+    }
+
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[8], vec_C[4];
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        __builtin_mma_xxsetaccz(&acc_0);
+        __builtin_mma_xxsetaccz(&acc_1);
+        __builtin_mma_xxsetaccz(&acc_2);
+        __builtin_mma_xxsetaccz(&acc_3);
+        for (int l = 0; l < k; l+=8) {
+            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
+            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
+            for (int x = 0; x < 4; x++) {
+                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
+                __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
+                __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
+            }
+        }
+
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+        SAVE_ACC(&acc_2, ii+4, jj);
+        SAVE_ACC(&acc_3, ii+4, jj+4);
+    }
+
+    template<int RM, int RN>
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[2], vec_B[2];
+            for (int l=0; l<k; l+=4) {
+                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
+                for (int x = 0; x<2; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM>
+    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int RN = 8;
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0, acc_1;
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            vec_t vec_A[4], vec_B[8];
+            for (int l=0; l<k; l+=8) {
+                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
+                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
+                for (int x = 0; x<4; x++) {
+                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_1);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii,jj);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_PPC {
   public:
@@ -1092,13 +1580,403 @@ class tinyBLAS_Q0_PPC {
         }
     }
 
-    template<typename VA, typename VB>
-    void
+    template<typename VA, typename VB, int size>
+    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
         int64_t i, j;
         TA *aoffset = NULL;
         VA *vecOffset = NULL;
         TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
         TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VB t1, t2, t3, t4, t5, t6, t7, t8;
+        const vector signed char lowMask = vec_splats((signed char)0xF);
+        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+        const vector signed char v8 = vec_splats((signed char)0x8);
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+
+                i = (cols >> 2);
+                if (i > 0) {
+                    do {
+                        c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                        c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+                        c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
+                        c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
+                        c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
+                        c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
+
+                        c1[0] = vec_and(c1[1], lowMask);
+                        c1[1] = vec_sr(c1[1], v4);
+                        c1[0] = vec_sub(c1[0], v8);
+                        c1[1] = vec_sub(c1[1], v8);
+                        vsum = vec_sum4s(c1[0], vsum);
+                        vsum2 = vec_sum4s(c1[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c2[0] = vec_and(c2[1], lowMask);
+                        c2[1] = vec_sr(c2[1], v4);
+                        c2[0] = vec_sub(c2[0], v8);
+                        c2[1] = vec_sub(c2[1], v8);
+                        vsum = vec_sum4s(c2[0], vsum);
+                        vsum2 = vec_sum4s(c2[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c3[0] = vec_and(c3[1], lowMask);
+                        c3[1] = vec_sr(c3[1], v4);
+                        c3[0] = vec_sub(c3[0], v8);
+                        c3[1] = vec_sub(c3[1], v8);
+                        vsum = vec_sum4s(c3[0], vsum);
+                        vsum2 = vec_sum4s(c3[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c4[0] = vec_and(c4[1], lowMask);
+                        c4[1] = vec_sr(c4[1], v4);
+                        c4[0] = vec_sub(c4[0], v8);
+                        c4[1] = vec_sub(c4[1], v8);
+                        vsum = vec_sum4s(c4[0], vsum);
+                        vsum2 = vec_sum4s(c4[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c5[0] = vec_and(c5[1], lowMask);
+                        c5[1] = vec_sr(c5[1], v4);
+                        c5[0] = vec_sub(c5[0], v8);
+                        c5[1] = vec_sub(c5[1], v8);
+                        vsum = vec_sum4s(c5[0], vsum);
+                        vsum2 = vec_sum4s(c5[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c6[0] = vec_and(c6[1], lowMask);
+                        c6[1] = vec_sr(c6[1], v4);
+                        c6[0] = vec_sub(c6[0], v8);
+                        c6[1] = vec_sub(c6[1], v8);
+                        vsum = vec_sum4s(c6[0], vsum);
+                        vsum2 = vec_sum4s(c6[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c7[0] = vec_and(c7[1], lowMask);
+                        c7[1] = vec_sr(c7[1], v4);
+                        c7[0] = vec_sub(c7[0], v8);
+                        c7[1] = vec_sub(c7[1], v8);
+                        vsum = vec_sum4s(c7[0], vsum);
+                        vsum2 = vec_sum4s(c7[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        c8[0] = vec_and(c8[1], lowMask);
+                        c8[1] = vec_sr(c8[1], v4);
+                        c8[0] = vec_sub(c8[0], v8);
+                        c8[1] = vec_sub(c8[1], v8);
+                        vsum = vec_sum4s(c8[0], vsum);
+                        vsum2 = vec_sum4s(c8[1], vsum2);
+                        vsum = vec_add(vsum, vsum2);
+                        comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                        vsum = vec_splats(0);
+                        vsum2 = vec_splats(0);
+
+                        t1 = vec_perm(c1[0], c2[0], swiz1);
+                        t2 = vec_perm(c1[0], c2[0], swiz2);
+                        t3 = vec_perm(c3[0], c4[0], swiz1);
+                        t4 = vec_perm(c3[0], c4[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset);
+                        vec_xst(t6, 0, vecOffset+16);
+                        vec_xst(t7, 0, vecOffset+32);
+                        vec_xst(t8, 0, vecOffset+48);
+
+                        t1 = vec_perm(c1[1], c2[1], swiz1);
+                        t2 = vec_perm(c1[1], c2[1], swiz2);
+                        t3 = vec_perm(c3[1], c4[1], swiz1);
+                        t4 = vec_perm(c3[1], c4[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+64);
+                        vec_xst(t6, 0, vecOffset+80);
+                        vec_xst(t7, 0, vecOffset+96);
+                        vec_xst(t8, 0, vecOffset+112);
+
+                        t1 = vec_perm(c5[0], c6[0], swiz1);
+                        t2 = vec_perm(c5[0], c6[0], swiz2);
+                        t3 = vec_perm(c7[0], c8[0], swiz1);
+                        t4 = vec_perm(c7[0], c8[0], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+128);
+                        vec_xst(t6, 0, vecOffset+144);
+                        vec_xst(t7, 0, vecOffset+160);
+                        vec_xst(t8, 0, vecOffset+176);
+
+                        t1 = vec_perm(c5[1], c6[1], swiz1);
+                        t2 = vec_perm(c5[1], c6[1], swiz2);
+                        t3 = vec_perm(c7[1], c8[1], swiz1);
+                        t4 = vec_perm(c7[1], c8[1], swiz2);
+                        t5 = vec_perm(t1, t3, swiz3);
+                        t6 = vec_perm(t1, t3, swiz4);
+                        t7 = vec_perm(t2, t4, swiz3);
+                        t8 = vec_perm(t2, t4, swiz4);
+                        vec_xst(t5, 0, vecOffset+192);
+                        vec_xst(t6, 0, vecOffset+208);
+                        vec_xst(t7, 0, vecOffset+224);
+                        vec_xst(t8, 0, vecOffset+240);
+
+                        aoffset1 += lda;
+                        aoffset2 += lda;
+                        aoffset3 += lda;
+                        aoffset4 += lda;
+                        aoffset5 += lda;
+                        aoffset6 += lda;
+                        aoffset7 += lda;
+                        aoffset8 += lda;
+                        vecOffset += 256;
+                        i--;
+                    } while (i > 0);
+                }
+                j--;
+            } while (j > 0);
+        }
+
+        if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                    c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                    c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                    c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
+
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats( 0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while (i > 0);
+            }
+        }
+
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            i = (cols >> 2);
+            if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
+                        case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
+                        case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                            break;
+                    }
+                    c1[0] = vec_and(c1[1], lowMask);
+                    c1[1] = vec_sr(c1[1], v4);
+                    c1[0] = vec_sub(c1[0], v8);
+                    c1[1] = vec_sub(c1[1], v8);
+                    vsum = vec_sum4s(c1[0], vsum);
+                    vsum2 = vec_sum4s(c1[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c2[0] = vec_and(c2[1], lowMask);
+                    c2[1] = vec_sr(c2[1], v4);
+                    c2[0] = vec_sub(c2[0], v8);
+                    c2[1] = vec_sub(c2[1], v8);
+                    vsum = vec_sum4s(c2[0], vsum);
+                    vsum2 = vec_sum4s(c2[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c3[0] = vec_and(c3[1], lowMask);
+                    c3[1] = vec_sr(c3[1], v4);
+                    c3[0] = vec_sub(c3[0], v8);
+                    c3[1] = vec_sub(c3[1], v8);
+                    vsum = vec_sum4s(c3[0], vsum);
+                    vsum2 = vec_sum4s(c3[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    c4[0] = vec_and(c4[1], lowMask);
+                    c4[1] = vec_sr(c4[1], v4);
+                    c4[0] = vec_sub(c4[0], v8);
+                    c4[1] = vec_sub(c4[1], v8);
+                    vsum = vec_sum4s(c4[0], vsum);
+                    vsum2 = vec_sum4s(c4[1], vsum2);
+                    vsum = vec_add(vsum, vsum2);
+                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+                    vsum = vec_splats(0);
+                    vsum2 = vec_splats(0);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    vecOffset += 128;
+                    i--;
+                } while(i > 0);
+            }
+        }
+    }
+
+    template<typename VA, typename VB>
+    void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+        int64_t i, j;
+        TB *aoffset = NULL;
+        VA *vecOffset = NULL;
+        TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
         __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
         VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
         VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
@@ -1111,24 +1989,24 @@ class tinyBLAS_Q0_PPC {
         vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
         vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
 
-        aoffset = const_cast<
+        aoffset = const_cast<TB*>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
-
-
-
-
-
-
-
-
-
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
 
-
-
-
+                i = (cols >> 3);
+                if (i > 0) {
+                    do {
                         C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
                         C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
                         C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
@@ -1156,10 +2034,10 @@ class tinyBLAS_Q0_PPC {
|
|
|
1156
2034
|
t7 = vec_perm(t2, t4, swiz3);
|
|
1157
2035
|
t8 = vec_perm(t2, t4, swiz4);
|
|
1158
2036
|
if (flip == true) {
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
2037
|
+
t5 = vec_xor(t5, xor_vector);
|
|
2038
|
+
t6 = vec_xor(t6, xor_vector);
|
|
2039
|
+
t7 = vec_xor(t7, xor_vector);
|
|
2040
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1163
2041
|
}
|
|
1164
2042
|
vec_xst(t5, 0, vecOffset);
|
|
1165
2043
|
vec_xst(t6, 0, vecOffset+16);
|
|
@@ -1175,10 +2053,10 @@ class tinyBLAS_Q0_PPC {
|
|
|
1175
2053
|
t7 = vec_perm(t2, t4, swiz3);
|
|
1176
2054
|
t8 = vec_perm(t2, t4, swiz4);
|
|
1177
2055
|
if (flip == true) {
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
2056
|
+
t5 = vec_xor(t5, xor_vector);
|
|
2057
|
+
t6 = vec_xor(t6, xor_vector);
|
|
2058
|
+
t7 = vec_xor(t7, xor_vector);
|
|
2059
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1182
2060
|
}
|
|
1183
2061
|
vec_xst(t5, 0, vecOffset+64);
|
|
1184
2062
|
vec_xst(t6, 0, vecOffset+80);
|
|
@@ -1194,10 +2072,10 @@ class tinyBLAS_Q0_PPC {
|
|
|
1194
2072
|
t7 = vec_perm(t2, t4, swiz3);
|
|
1195
2073
|
t8 = vec_perm(t2, t4, swiz4);
|
|
1196
2074
|
if (flip == true) {
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
2075
|
+
t5 = vec_xor(t5, xor_vector);
|
|
2076
|
+
t6 = vec_xor(t6, xor_vector);
|
|
2077
|
+
t7 = vec_xor(t7, xor_vector);
|
|
2078
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1201
2079
|
}
|
|
1202
2080
|
vec_xst(t5, 0, vecOffset+128);
|
|
1203
2081
|
vec_xst(t6, 0, vecOffset+144);
|
|
@@ -1213,10 +2091,10 @@ class tinyBLAS_Q0_PPC {
|
|
|
1213
2091
|
t7 = vec_perm(t2, t4, swiz3);
|
|
1214
2092
|
t8 = vec_perm(t2, t4, swiz4);
|
|
1215
2093
|
if (flip == true) {
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
2094
|
+
t5 = vec_xor(t5, xor_vector);
|
|
2095
|
+
t6 = vec_xor(t6, xor_vector);
|
|
2096
|
+
t7 = vec_xor(t7, xor_vector);
|
|
2097
|
+
t8 = vec_xor(t8, xor_vector);
|
|
1220
2098
|
}
|
|
1221
2099
|
vec_xst(t5, 0, vecOffset+192);
|
|
1222
2100
|
vec_xst(t6, 0, vecOffset+208);
|
|
@@ -1240,11 +2118,11 @@ class tinyBLAS_Q0_PPC {
|
|
|
1240
2118
|
}
|
|
1241
2119
|
|
|
1242
2120
|
if (rows & 4) {
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
2121
|
+
aoffset1 = aoffset;
|
|
2122
|
+
aoffset2 = aoffset1 + lda;
|
|
2123
|
+
aoffset3 = aoffset2 + lda;
|
|
2124
|
+
aoffset4 = aoffset3 + lda;
|
|
2125
|
+
aoffset += 4 * lda;
|
|
1248
2126
|
|
|
1249
2127
|
i = (cols >> 3);
|
|
1250
2128
|
if (i > 0) {
|
|
@@ -1311,7 +2189,7 @@ class tinyBLAS_Q0_PPC {
|
|
|
1311
2189
|
aoffset2 = aoffset1 + lda;
|
|
1312
2190
|
aoffset3 = aoffset2 + lda;
|
|
1313
2191
|
i = (cols >> 3);
|
|
1314
|
-
|
|
2192
|
+
if (i > 0) {
|
|
1315
2193
|
do {
|
|
1316
2194
|
switch(rows) {
|
|
1317
2195
|
case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
|
|
@@ -1527,13 +2405,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_4x8(int64_t ii, int64_t jj) {
         vec_t vec_A[8], vec_B[16] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 4> comparray;
+        std::array<int, 4> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-
+            if (std::is_same_v<TA, block_q4_0>) {
+                packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
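Note on the pattern above: the new `isAblock_q4` flag and the `std::is_same_v<TA, block_q4_0>` branch pick the packing routine per element type at compile time, so the q4_0 support reuses the existing q8_0 kernels. A minimal sketch of the idiom (`blk_q4`, `blk_q8`, and `pack_A` are illustrative stand-ins, not the library's API):

    #include <cstdio>
    #include <type_traits>

    struct blk_q4 {};  // illustrative stand-in for block_q4_0
    struct blk_q8 {};  // illustrative stand-in for block_q8_0

    template <typename TA>
    void pack_A(const TA *) {
        // std::is_same_v is a compile-time constant, so the untaken branch is
        // dead code. With a plain `if` (as in the diff) both branches must
        // still compile for every TA, which is why the q8 path casts the
        // pointer; `if constexpr` would lift that requirement.
        if (std::is_same_v<TA, blk_q4>) {
            std::puts("int4 packing + row sums (packNormalInt4 path)");
        } else {
            std::puts("generic int8 packing (packNormal path)");
        }
    }

    int main() {
        blk_q4 a4;
        blk_q8 a8;
        pack_A(&a4);
        pack_A(&a8);
    }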
@@ -1545,15 +2428,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-
-
-
-
-
-
-
-
-
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 4; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
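The row-sum loop above exists because `__builtin_mma_xvi8ger4pp` accumulates signed A bytes against B bytes that were packed with `flip = true`, i.e. XOR 0x80, which adds 128 to each signed value; the kernels later undo that bias via `comparray[i] * -128.0`. A scalar model of the identity (hypothetical helper, not the library's API):

    #include <cassert>
    #include <cstdint>

    // acc accumulates a[i] * (b[i] XOR 0x80) = a[i] * (b[i] + 128) under the
    // unsigned reinterpretation; subtracting 128 * sum(a) recovers the plain
    // signed dot product -- the role played by comparray in the kernels.
    int32_t dot_s8_via_biased_u8(const int8_t *a, const int8_t *b, int n) {
        int32_t acc = 0, rowsum = 0;
        for (int i = 0; i < n; i++) {
            uint8_t bu = (uint8_t)(b[i] ^ 0x80);  // flip = true packing
            acc += (int32_t)a[i] * (int32_t)bu;
            rowsum += a[i];
        }
        return acc - 128 * rowsum;
    }

    int main() {
        int8_t a[4] = {1, -2, 3, -4};
        int8_t b[4] = {-5, 6, -7, 8};
        int32_t ref = 1*-5 + -2*6 + 3*-7 + -4*8;  // -70
        assert(dot_s8_via_biased_u8(a, b, 4) == ref);
    }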
@@ -1565,13 +2450,18 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_8x4(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[8] = {0};
         acc_t acc_0, acc_1;
-        std::array<int, 8> comparray;
+        std::array<int, 8> comparray {};
         vector float fin_res[8] = {0};
         vector float vs[8] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
-
+            if (std::is_same_v<TA, block_q4_0>) {
+                packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1582,15 +2472,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                 }
             }
-
-
-
-
-
-
-
-
-
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1602,15 +2494,20 @@ class tinyBLAS_Q0_PPC {
     void KERNEL_8x8(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[16] = {0};
         acc_t acc_0, acc_1, acc_2, acc_3;
-        std::array<int, 8> comparray;
+        std::array<int, 8> comparray {};
         vector float fin_res[16] = {0};
         vector float vs[16] = {0};
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
             __builtin_mma_xxsetaccz(&acc_2);
             __builtin_mma_xxsetaccz(&acc_3);
-
+            if (std::is_same_v<TA, block_q4_0>) {
+                packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1624,15 +2521,17 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-
-
-
-
-
-
-
-
-
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
@@ -1653,16 +2552,17 @@ class tinyBLAS_Q0_PPC {
         int64_t duty = (tiles + nth - 1) / nth;
         int64_t start = duty * ith;
         int64_t end = start + duty;
-        vec_t vec_A[8], vec_B[8] = {0};
+        vec_t vec_A[8] = {0}, vec_B[8] = {0};
         vector signed int vec_C[4];
         acc_t acc_0;
+        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
 
         if (end > tiles)
             end = tiles;
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            std::array<int, 4> comparray;
+            std::array<int, 4> comparray{};
             vector float res[4] = {0};
             vector float fin_res[4] = {0};
             vector float vs[4] = {0};
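`comparray{}` here (like the `comparray {}` and `vec_A[8] = {0}` changes above) swaps default-initialization, which leaves the elements indeterminate, for value-initialization, which zeroes them; presumably a defensive change, since each element type now fills `comparray` on a different path. The distinction in isolation:

    #include <array>

    int main() {
        std::array<int, 4> a;    // default-initialized: element values indeterminate
        std::array<int, 4> b{};  // value-initialized: all four elements are zero
        (void)a;                 // reading a's elements here would be undefined
        return b[0];             // guaranteed 0
    }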
@@ -1673,7 +2573,11 @@ class tinyBLAS_Q0_PPC {
             __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
             __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
             __builtin_mma_xxsetaccz(&acc_0);
-
+            if (isAblock_q4) {
+                packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+            }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x+=4) {
                 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -1687,17 +2591,18 @@ class tinyBLAS_Q0_PPC {
                 }
             }
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
-
-
-
-
-
-
-
-
-
+            if (!isAblock_q4) {
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < RM; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
             }
-
             for (int i = 0; i < RM; i++) {
                 CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
                 res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
@@ -1784,6 +2689,7 @@ class tinyBLAS_PPC {
            boffset = vec;
            j = (rows >> 3);
            if (j > 0) {
+
                do {
                    aoffset1 = aoffset;
                    aoffset2 = aoffset1 + lda;
@@ -2013,6 +2919,7 @@ class tinyBLAS_PPC {
            }
        }
    }
+
    void KERNEL_4x4(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[4], vec_C[4];
        acc_t acc_0;
@@ -2259,15 +3166,27 @@ class tinyBLAS_PPC {
        vec_t vec_C[4];
        acc_t acc_0;
        __builtin_mma_xxsetaccz(&acc_0);
-        vec_t vec_A[4], vec_B[4];
+        vec_t vec_A[4] {0}, vec_B[4] = {0};
        for (int l=0; l<k; l+=4) {
-
+            /* 'GEMV Forwarding' concept is used in first two conditional loops.
+             * when one of the matrix has a single row/column, the elements are
+             * broadcasted, instead of using packing routine to prepack the
+             * matrix elements.
+             */
+            if (RM == 1) {
                TA* a = const_cast<TA*>(A+(ii)*lda+l);
-                packTranspose<vector float>(B+(jj*ldb)+l, ldb,
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
                vec_A[0] = (vec_t)vec_xl(0,a);
                vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
                vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
                vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+            } else if (RN == 1) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+                vec_B[0] = (vec_t)vec_xl(0,b);
+                vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
+                vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
+                vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
            } else {
                packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
                packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
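The comment block in this hunk names the trick: when the tile has a single row (`RM == 1`) or column (`RN == 1`), the 4x4 pack/transpose of that operand is skipped and its four scalars are each broadcast across a vector register (`vec_splats`), so the regular 4x4 tile kernel still applies. A minimal sketch of the broadcast step, assuming 4-wide float lanes (`lane4` and `forward_row` are illustrative names, not the library's API):

    #include <array>

    using lane4 = std::array<float, 4>;

    static lane4 splat(float x) { return {x, x, x, x}; }

    // Single-row tile: load four consecutive values and broadcast each across
    // all lanes, standing in for the vec_splats calls in the diff.
    static std::array<lane4, 4> forward_row(const float *a) {
        return { splat(a[0]), splat(a[1]), splat(a[2]), splat(a[3]) };
    }

    int main() {
        const float row[4] = {1.f, 2.f, 3.f, 4.f};
        auto v = forward_row(row);
        return v[2][0] == 3.f ? 0 : 1;  // every lane of v[2] holds row[2]
    }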
@@ -2371,8 +3290,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
     assert(params->ith < params->nth);
 
     // only enable sgemm for prompt processing
+#if !defined(__MMA__)
     if (n < 2)
         return false;
+#endif
 
     if (Ctype != GGML_TYPE_F32)
         return false;
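This guard narrows an old blanket rule: previously every target rejected `n < 2`, reserving `llamafile_sgemm` for prompt processing; PowerPC MMA builds now let single-column products (token generation, effectively a GEMV) reach the forwarding paths added in this diff. Restated for clarity (same semantics as the diff, not new behavior):

    // Only non-MMA targets still reserve sgemm for prompt processing (n >= 2);
    // a false return here means the caller uses ggml's generic matmul path.
    #if !defined(__MMA__)
        if (n < 2)
            return false;
    #endif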
@@ -2442,9 +3363,22 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                             (float *)C, ldc};
         return tb.matmul(m, n);
     }
+#elif defined(__MMA__)
+    if ((k % 8))
+        return false;
+    if(Btype == GGML_TYPE_BF16) {
+        tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            (const ggml_bf16_t *)A, lda,
+            (const ggml_bf16_t *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
+    }
 #endif
     return false;
 }
+
 case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
     if (Btype == GGML_TYPE_F16) {
@@ -2503,8 +3437,8 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
-
 #elif defined(__MMA__)
+    //TO-DO: Remove this condition once gemv forwarding is enabled.
     if (n < 8 && n != 4)
         return false;
     if (m < 8 && m != 4)
@@ -2516,7 +3450,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
-
 #else
     return false;
 #endif
@@ -2541,6 +3474,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__MMA__)
+    //TO-DO: Remove this condition once gemv forwarding is enabled.
+    if (n < 8 && n != 4)
+        return false;
+    if (m < 8 && m != 4)
+        return false;
+    tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
+        k, (const block_q4_0 *)A, lda,
+        (const block_q8_0 *)B, ldb,
+        (float *)C, ldc,
+        params->ith, params->nth};
+    tb.matmul(m, n);
+    return true;
 #else
     return false;
 #endif
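The q4_0 case now mirrors the f16 one: on PowerPC MMA builds, `tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float>` handles q4_0 x q8_0 products directly, with the temporary shape restriction flagged by the TO-DO: each dimension must be exactly 4 or at least 8 until GEMV forwarding is enabled for this path, presumably because the kernels are tiled in 4s and 8s. The gate as a standalone predicate (hypothetical helper name):

    #include <cstdint>

    // Hypothetical helper restating the shape gate from the diff: a dimension
    // maps onto full 4x4/4x8/8x4/8x8 tiles only when it is 4 or at least 8.
    static bool mma_q4_shape_ok(int64_t m, int64_t n) {
        const bool m_ok = (m >= 8) || (m == 4);
        const bool n_ok = (n >= 8) || (n == 4);
        return m_ok && n_ok;
    }

    int main() {
        return (mma_q4_shape_ok(8, 4) && !mma_q4_shape_ok(5, 8)) ? 0 : 1;
    }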