@fugood/llama.node 0.3.3 → 0.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +29 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +17 -1
  21. package/src/LlamaContext.cpp +86 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
@@ -1,24 +1,57 @@
- // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
- // SPDX-License-Identifier: MIT
- //
-
- #define GGML_COMMON_IMPL_C
+ #define GGML_COMMON_IMPL_CPP
+ #define GGML_COMMON_DECL_CPP
  #include "ggml-common.h"
+ #include "ggml-backend-impl.h"

  #include "ggml-quants.h"
  #include "ggml-impl.h"
  #include "ggml-cpu.h"
- #include "ggml-cpu/ggml-cpu-impl.h"
+ #include "ggml-cpu-impl.h"
+ #include "ggml-cpu-traits.h"

- #include <math.h>
- #include <string.h>
- #include <assert.h>
- #include <float.h>
- #include <stdlib.h> // for qsort
- #include <stdio.h> // for GGML_ASSERT
+ #include <cmath>
+ #include <cstring>
+ #include <cassert>
+ #include <cfloat>
+ #include <cstdlib> // for qsort
+ #include <cstdio> // for GGML_ASSERT

  #include "ggml-cpu-aarch64.h"

+ // TODO: move to include file?
+ template <int K> constexpr int QK_0() {
+ if constexpr (K == 4) {
+ return QK4_0;
+ }
+ if constexpr (K == 8) {
+ return QK8_0;
+ }
+ return -1;
+ }
+
+ template <int K, int N> struct block {
+ ggml_half d[N]; // deltas for N qK_0 blocks
+ int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
+ };
+
+ // control size
+ static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+ static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+ static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+ static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+
+ using block_q4_0x4 = block<4, 4>;
+ using block_q4_0x8 = block<4, 8>;
+ using block_q8_0x4 = block<8, 4>;
+ using block_q8_0x8 = block<8, 8>;
+
+ struct block_iq4_nlx4 {
+ ggml_half d[4]; // deltas for 4 iq4_nl blocks
+ uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
+ };
+
+ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
  #if defined(__GNUC__)
  #pragma GCC diagnostic ignored "-Woverlength-strings"
  #elif defined(_MSC_VER)
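The new `block<K, N>` template packs N interleaved qK_0 blocks (N half-precision deltas followed by the packed quants), and the static_asserts above pin the struct to its expected size. A quick check of that arithmetic, as a stand-alone sketch assuming ggml's usual block sizes (QK4_0 == QK8_0 == 32) and a 2-byte ggml_half:

    #include <cstdint>
    using ggml_half = std::uint16_t;       // assumption: fp16 delta stored in 2 bytes
    constexpr int QK4_0 = 32, QK8_0 = 32;  // assumption: ggml block sizes
    // block<4, 4>: 4 deltas (8 bytes) + (32 values * 4 blocks * 4 bits) / 8 = 64 bytes of nibbles -> 72 total
    static_assert(4 * sizeof(ggml_half) + (QK4_0 * 4 * 4) / 8 == 4 * sizeof(ggml_half) + QK8_0 * 2, "block<4,4> is 72 bytes");
    // block<8, 8>: 8 deltas (16 bytes) + (32 values * 8 blocks * 8 bits) / 8 = 256 bytes of int8 quants -> 272 total
    static_assert(8 * sizeof(ggml_half) + (QK8_0 * 8 * 8) / 8 == 8 * sizeof(ggml_half) + QK8_0 * 8, "block<8,8> is 272 bytes");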
@@ -132,7 +165,7 @@ static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
  }

  static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+ #if defined(__AVX512VNNI__)
  const __m512i zero = _mm512_setzero_si512();
  return _mm512_dpbusd_epi32(zero, ax, sy);
  #else
@@ -187,12 +220,14 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
  }
  #endif

- static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
+ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  assert(QK8_0 == 32);
  assert(k % QK8_0 == 0);
  const int nb = k / QK8_0;

- block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+ block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;

  #if defined(__ARM_NEON)
  float32x4_t srcv[4][8];
@@ -281,12 +316,12 @@ static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int6
  #endif
  }

- static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
+ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  assert(QK8_0 == 32);
  assert(k % QK8_0 == 0);
  const int nb = k / QK8_0;

- block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+ block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;

  #if defined(__ARM_NEON)
  float32x4_t srcv[4][8];
@@ -496,7 +531,7 @@ static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int6
  #endif
  }

- void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+ static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
  assert(nrow == 4);
  UNUSED(nrow);
  if (blck_size_interleave == 4) {
@@ -508,7 +543,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
  }
  }

- void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
@@ -527,67 +562,47 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
  UNUSED(ncols_interleaved);
  UNUSED(blocklen);

- #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
- if (ggml_cpu_has_neon()) {
- const void * b_ptr = vx;
- const void * a_ptr = vy;
- float * res_ptr = s;
-
- __asm__ __volatile__(
- "movi v31.16b, #0x4\n"
- "movi v30.16b, #0xf0\n"
- "add %x[b_ptr], %x[b_ptr], #0x8\n"
- "1:" // Column loop
- "add x22, %x[a_ptr], #0x2\n"
- "movi v29.16b, #0x0\n"
- "mov x21, %x[nb]\n"
- "2:" // Block loop
- "ldr q28, [%x[b_ptr], #0x0]\n"
- "ldr q27, [x22, #0x0]\n"
- "movi v26.4s, #0x0\n"
- "sub x20, x22, #0x2\n"
- "ldr q25, [x22, #0x10]\n"
- "ldr q24, [%x[b_ptr], #0x10]\n"
- "sub x21, x21, #0x1\n"
- "add x22, x22, #0x22\n"
- "ldr q23, [%x[b_ptr], #0x20]\n"
- "ldr q22, [%x[b_ptr], #0x30]\n"
- "ld1r { v21.8h }, [x20]\n"
- "ldr q20, [%x[b_ptr], #-0x8]\n"
- "sshl v16.16b, v28.16b, v31.16b\n"
- "and v28.16b, v28.16b, v30.16b\n"
- "sshl v19.16b, v24.16b, v31.16b\n"
- "and v24.16b, v24.16b, v30.16b\n"
- "add %x[b_ptr], %x[b_ptr], #0x48\n"
- "sshl v18.16b, v23.16b, v31.16b\n"
- "and v23.16b, v23.16b, v30.16b\n"
- ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
- "sshl v17.16b, v22.16b, v31.16b\n"
- "and v22.16b, v22.16b, v30.16b\n"
- "fcvtl v21.4s, v21.4h\n"
- "fcvtl v16.4s, v20.4h\n"
- ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
- "fmul v16.4s, v16.4s, v21.4s\n"
- ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
- ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
- ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
- ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
- ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
- ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
- "scvtf v26.4s, v26.4s, #0x4\n"
- "fmla v29.4s, v26.4s, v16.4s\n"
- "cbnz x21, 2b\n"
- "sub %x[nc], %x[nc], #0x4\n"
- "str q29, [%x[res_ptr], #0x0]\n"
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
- "cbnz %x[nc], 1b\n"
- : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
- : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
- : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
- );
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
+
+ for (int c = 0; c < nc; c += ncols_interleaved) {
+ const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
+ float32x4_t acc = vdupq_n_f32(0);
+ for (int b = 0; b < nb; b++) {
+ int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
+ int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
+ int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
+ int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
+ float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
+
+ int8x16_t a0 = vld1q_s8(a_ptr->qs);
+ int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
+ float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
+
+ int32x4_t ret = vdupq_n_s32(0);
+
+ ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
+ ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
+ ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
+ ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);
+
+ ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
+ ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
+ ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
+ ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);
+
+ acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+ vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+ a_ptr++;
+ b_ptr++;
+ }
+ vst1q_f32(s, acc);
+ s += ncols_interleaved;
+ }
  return;
  }
- #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
  float sumf[4];
  int sumi;
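In the rewritten NEON path above, the hand-written assembly is replaced by equivalent intrinsics and is now gated on dot-product support (`__ARM_FEATURE_DOTPROD` / `ggml_cpu_has_dotprod()`). The `b << 4` and `b & 0xf0U` operands place each 4-bit weight in the high half of an int8, so every product entering `vdotq_laneq_s32` carries an extra factor of 16, which the final `vcvtq_n_f32_s32(ret, 4)` conversion (divide by 2^4) removes. A minimal scalar model of the by-element dot product used here, assuming the standard SDOT-by-element semantics:

    #include <cstdint>
    // Each 32-bit lane i of acc accumulates a 4-way int8 dot product of a[4i..4i+3]
    // against the same 4 bytes of b, selected by `lane`.
    static inline void vdotq_laneq_s32_model(int32_t acc[4], const int8_t a[16],
                                             const int8_t b[16], int lane) {
        for (int i = 0; i < 4; i++) {
            for (int j = 0; j < 4; j++) {
                acc[i] += (int32_t) a[4 * i + j] * (int32_t) b[4 * lane + j];
            }
        }
    }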
@@ -613,7 +628,7 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
  }
  }

- void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
@@ -723,7 +738,7 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
  }
  }

- void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 8;
@@ -996,7 +1011,103 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
  }
  }

- void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert (n % qk == 0);
+ assert (nc % ncols_interleaved == 0);
+
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+ const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ float * res_ptr = s;
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+ float32x4_t sumf = vdupq_n_f32(0);
+ for (int l = 0; l < nb; l++) {
+ uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+ uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+ uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+ uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+ int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+ int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+ int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+ int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+ int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+ int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+ int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+ int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+ int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+ int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+ int32x4_t sumi = vdupq_n_s32(0);
+ sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+ sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+ sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+ sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+ sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+ sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+ sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+ sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+ float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+ float32x4_t d = a_d * b_d;
+
+ sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+ }
+
+ vst1q_f32(res_ptr + x * 4, sumf);
+ }
+ return;
+ }
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+ {
+ float sumf[4];
+ int sumi;
+
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+ const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+ }
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+ }
+ }
+ }
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+ }
+ }
+ }
+
+ static void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
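The new `ggml_gemv_iq4_nl_4x4_q8_0` kernel dequantizes IQ4_NL on the fly: each packed byte holds two 4-bit indices into the 16-entry `kvalues_iq4nl` codebook, which the NEON path expands with a `vqtbl1q_s8` table lookup and the scalar fallback expands with plain array indexing. A minimal sketch of that per-byte lookup (hypothetical helper, not part of the diff):

    #include <cstdint>
    // Unpack one packed byte into its two codebook values.
    static inline void iq4_nl_unpack_byte(uint8_t packed, const int8_t kvalues[16],
                                          int8_t * lo, int8_t * hi) {
        *lo = kvalues[packed & 0x0F];  // low nibble -> first value of the pair
        *hi = kvalues[packed >> 4];    // high nibble -> second value of the pair
    }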
@@ -1017,7 +1128,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
  UNUSED(blocklen);

  #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
- if (ggml_cpu_has_neon()) {
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
  const void * b_ptr = vx;
  const void * a_ptr = vy;
  float * res_ptr = s;
@@ -1512,7 +1623,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
  }
  }

- void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
@@ -1966,7 +2077,7 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
  }
  }

- void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 8;
@@ -2486,31 +2597,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
  const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)

  // Shuffle pattern one - right side input
- const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
- const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)

- const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
- const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)

- const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
- const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)

- const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
- const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)

  // Shuffle pattern two - right side input

- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)

- const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
- const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)

- const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
- const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)

- const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
- const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)

  // Scale values - Load the weight scale values of two block_q4_0x8
  const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
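The only change in these hunks is the explicit `(_MM_PERM_ENUM)` cast on the shuffle immediates. GCC's AVX-512 headers declare the control argument of `_mm512_shuffle_epi32` as `_MM_PERM_ENUM`, and unlike C, C++ does not implicitly convert an `int` to an enum type, so compiling the renamed .cpp file requires the casts; the shuffle patterns themselves are unchanged. The immediate is just four 2-bit element selectors, as this small illustrative sketch decodes (assuming the standard per-128-bit-lane encoding):

    // Decode one 2-bit selector from the 8-bit shuffle control.
    constexpr int sel(int imm8, int pos) { return (imm8 >> (2 * pos)) & 0x3; }
    static_assert(sel(136, 0) == 0 && sel(136, 1) == 2 && sel(136, 2) == 0 && sel(136, 3) == 2,
                  "136 = 0b10001000 picks elements {0,2,0,2}");
    static_assert(sel(221, 0) == 1 && sel(221, 1) == 3 && sel(221, 2) == 1 && sel(221, 3) == 3,
                  "221 = 0b11011101 picks elements {1,3,1,3}");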
@@ -2544,31 +2655,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *

  // Shuffle pattern one - left side input

- const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
- const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)

- const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
- const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)

- const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)

- const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
- const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)

  // Shuffle pattern two - left side input

- const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
- const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)

- const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
- const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)

- const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)

- const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
- const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)

  // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
  // Resembles MMLAs into 2x2 matrices in ARM Version
@@ -2597,10 +2708,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *


  // Straighten out to make 4 row vectors
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);

  // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
  const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
@@ -2679,31 +2790,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
2679
2790
  const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
2680
2791
 
2681
2792
  // Shuffle pattern one - right side input
2682
- const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2683
- const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2793
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2794
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2684
2795
 
2685
- const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2686
- const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2796
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2797
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2687
2798
 
2688
- const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2689
- const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2799
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2800
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2690
2801
 
2691
- const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2692
- const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2802
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2803
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2693
2804
 
2694
2805
  // Shuffle pattern two - right side input
2695
2806
 
2696
- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2697
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2807
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2808
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2698
2809
 
2699
- const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2700
- const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2810
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2811
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2701
2812
 
2702
- const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2703
- const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2813
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2814
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2704
2815
 
2705
- const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2706
- const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2816
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2817
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2707
2818
 
2708
2819
 
2709
2820
  // Scale values - Load the weight scale values of two block_q4_0x8
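Every change in these shuffle hunks is the same mechanical fix: GCC's AVX-512 headers declare the control operand of _mm512_shuffle_epi32 as _MM_PERM_ENUM, and since this file is now built as C++ a bare integer no longer converts implicitly, so the immediates gain an explicit cast. A minimal sketch (not part of the diff) of what the two immediates select, assuming <immintrin.h> and an AVX-512 target:

#include <immintrin.h>

// Per 128-bit sub-block the 8-bit control value holds four 2-bit lane
// selectors, destination lane 0 in the lowest bits:
//   136 == 0b10'00'10'00 -> source dwords {0,2,0,2}: even dwords, duplicated
//   221 == 0b11'01'11'01 -> source dwords {1,3,1,3}: odd dwords, duplicated
// The (_MM_PERM_ENUM) cast only satisfies C++'s int-to-enum rules; the
// generated instruction (vpshufd) is unchanged.
static inline __m512i pick_even_dwords(__m512i v) {
    return _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)136);
}
static inline __m512i pick_odd_dwords(__m512i v) {
    return _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)221);
}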
@@ -2735,31 +2846,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
2735
2846
 
2736
2847
  // Shuffle pattern one - left side input
2737
2848
 
2738
- const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2739
- const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2849
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2850
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2740
2851
 
2741
- const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2742
- const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2852
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2853
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2743
2854
 
2744
- const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2745
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2855
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2856
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2746
2857
 
2747
- const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2748
- const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2858
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2859
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2749
2860
 
2750
2861
  // Shuffle pattern two - left side input
2751
2862
 
2752
- const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2753
- const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2863
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2864
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2754
2865
 
2755
- const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2756
- const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2866
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2867
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2757
2868
 
2758
- const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2759
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2869
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2870
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2760
2871
 
2761
- const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2762
- const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2872
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2873
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2763
2874
 
2764
2875
  // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
2765
2876
  // Resembles MMLAs into 2x2 matrices in ARM Version
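The two comments above describe the core of the kernel: after shuffling, each 32-bit lane holds four corresponding int8 bytes from the left-hand and right-hand blocks, which are multiplied pairwise and accumulated into one int32 per lane. A scalar sketch of that per-lane dot product, for reference only:

#include <cstdint>

// Scalar sketch (reference only, not from the diff): the per-32-bit-lane dot
// product described above. The AVX-512 path performs 16 of these lanes at
// once per instruction.
static int32_t dot4_i8_lane(const int8_t a[4], const int8_t b[4], int32_t acc) {
    for (int i = 0; i < 4; ++i) {
        acc += (int32_t) a[i] * (int32_t) b[i];
    }
    return acc;
}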
@@ -2788,10 +2899,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
2788
2899
 
2789
2900
 
2790
2901
  // Straighten out to make 4 row vectors
2791
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
2792
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
2793
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
2794
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
2902
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
2903
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
2904
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
2905
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
2795
2906
 
2796
2907
  // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
2797
2908
  const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
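The blend/shuffle pairs above interleave the 2x2 accumulator tiles back into plain row vectors, and the same _MM_PERM_ENUM cast applies there. A sketch (not from the diff) of what one of those pairs does per 128-bit sub-block, assuming <immintrin.h>:

#include <immintrin.h>

// Immediate 78 == 0b01'00'11'10 selects source dwords {2,3,0,1}, i.e. it
// swaps the two 64-bit halves of each 128-bit sub-block; mask 0xCCCC then
// keeps dwords 0-1 from the first accumulator tile and dwords 2-3 from the
// swapped second tile, producing one straight output row.
static inline __m512i straighten_row(__m512i acc_00, __m512i acc_01) {
    return _mm512_mask_blend_epi32(0xCCCC, acc_00,
                                   _mm512_shuffle_epi32(acc_01, (_MM_PERM_ENUM)78));
}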
@@ -3386,7 +3497,117 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
3386
3497
  }
3387
3498
  }
3388
3499
 
3389
- // FIXME: this code is duplicated from ggml-aarch64.c
3500
+ static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
3501
+ const int qk = QK8_0;
3502
+ const int nb = n / qk;
3503
+ const int ncols_interleaved = 4;
3504
+ const int blocklen = 4;
3505
+
3506
+ assert (n % qk == 0);
3507
+ assert (nr % 4 == 0);
3508
+ assert (nc % ncols_interleaved == 0);
3509
+
3510
+ UNUSED(s);
3511
+ UNUSED(bs);
3512
+ UNUSED(vx);
3513
+ UNUSED(vy);
3514
+ UNUSED(nr);
3515
+ UNUSED(nc);
3516
+ UNUSED(nb);
3517
+ UNUSED(ncols_interleaved);
3518
+ UNUSED(blocklen);
3519
+
3520
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3521
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
3522
+ const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
3523
+
3524
+ for (int y = 0; y < nr / 4; y++) {
3525
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
3526
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
3527
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
3528
+
3529
+ float32x4_t sumf[4];
3530
+ for (int m = 0; m < 4; m++) {
3531
+ sumf[m] = vdupq_n_f32(0);
3532
+ }
3533
+
3534
+ for (int l = 0; l < nb; l++) {
3535
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
3536
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
3537
+
3538
+ int32x4_t sumi_0 = vdupq_n_s32(0);
3539
+ int32x4_t sumi_1 = vdupq_n_s32(0);
3540
+ int32x4_t sumi_2 = vdupq_n_s32(0);
3541
+ int32x4_t sumi_3 = vdupq_n_s32(0);
3542
+
3543
+ for (int k = 0; k < 4; k++) {
3544
+ int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
3545
+ int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
3546
+
3547
+ uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
3548
+ int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
3549
+ int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
3550
+
3551
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
3552
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
3553
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
3554
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
3555
+ sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
3556
+ sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
3557
+ sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
3558
+ sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
3559
+ }
3560
+
3561
+ sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
3562
+ sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
3563
+ sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
3564
+ sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
3565
+ }
3566
+
3567
+ for (int m = 0; m < 4; m++) {
3568
+ vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
3569
+ }
3570
+ }
3571
+ }
3572
+ return;
3573
+ }
3574
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
3575
+ {
3576
+ float sumf[4][4];
3577
+ int sumi;
3578
+
3579
+ for (int y = 0; y < nr / 4; y++) {
3580
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
3581
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
3582
+ const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
3583
+ for (int m = 0; m < 4; m++) {
3584
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
3585
+ }
3586
+ for (int l = 0; l < nb; l++) {
3587
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
3588
+ for (int m = 0; m < 4; m++) {
3589
+ for (int j = 0; j < ncols_interleaved; j++) {
3590
+ sumi = 0;
3591
+ for (int i = 0; i < blocklen; ++i) {
3592
+ const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
3593
+ const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
3594
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
3595
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
3596
+ }
3597
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
3598
+ }
3599
+ }
3600
+ }
3601
+ }
3602
+ for (int m = 0; m < 4; m++) {
3603
+ for (int j = 0; j < ncols_interleaved; j++)
3604
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
3605
+ }
3606
+ }
3607
+ }
3608
+ }
3609
+ }
3610
+
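The new ggml_gemm_iq4_nl_4x4_q8_0 above decodes two 4-bit indices per byte through the non-linear codebook kvalues_iq4nl: the NEON path does the lookup with one vqtbl1q_s8 per nibble half, and the scalar fallback indexes the table directly. A minimal scalar sketch of that decode step, assuming the 16-entry codebook:

#include <cstdint>

// Sketch (reference only): expanding one packed IQ4_NL byte. The low nibble
// belongs to the first half of the block and the high nibble to the second
// half, mirroring the b_lo / b_hi split in the NEON path above.
static inline void decode_iq4_nl_byte(uint8_t q, const int8_t kvalues[16],
                                      int8_t * lo, int8_t * hi) {
    *lo = kvalues[q & 0x0F];
    *hi = kvalues[q >> 4];
}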
3390
3611
  static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
3391
3612
  block_q4_0x4 out;
3392
3613
 
@@ -3456,20 +3677,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
3456
3677
  return out;
3457
3678
  }
3458
3679
 
3459
- static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
3680
+ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
3460
3681
  GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
3461
3682
  GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
3683
+ constexpr int nrows_interleaved = 4;
3462
3684
 
3463
3685
  block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
3464
3686
  const block_q4_0 * src = (const block_q4_0 *)data;
3465
3687
  block_q4_0 dst_tmp[4];
3466
- int nrow = t->ne[1]; // Number of rows
3467
- int nrows_interleaved = 4;
3688
+ int nrow = ggml_nrows(t);
3468
3689
  int nblocks = t->ne[0] / QK4_0;
3469
3690
 
3470
3691
  GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
3471
3692
 
3472
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3693
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3473
3694
  return -1;
3474
3695
  }
3475
3696
 
@@ -3487,20 +3708,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
3487
3708
  GGML_UNUSED(data_size);
3488
3709
  }
3489
3710
 
3490
- static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) {
3711
+ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
3491
3712
  GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
3492
3713
  GGML_ASSERT(interleave_block == 8);
3714
+ constexpr int nrows_interleaved = 8;
3493
3715
 
3494
3716
  block_q4_0x8 * dst = (block_q4_0x8*)t->data;
3495
3717
  const block_q4_0 * src = (const block_q4_0*) data;
3496
3718
  block_q4_0 dst_tmp[8];
3497
- int nrow = t->ne[1]; // Number of rows
3498
- int nrows_interleaved = 8;
3719
+ int nrow = ggml_nrows(t);
3499
3720
  int nblocks = t->ne[0] / QK4_0;
3500
3721
 
3501
3722
  GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
3502
3723
 
3503
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3724
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3504
3725
  return -1;
3505
3726
  }
3506
3727
 
@@ -3518,43 +3739,524 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block,
3518
3739
  GGML_UNUSED(data_size);
3519
3740
  }
3520
3741
 
3521
- // Prepare for optimized kernels if applicable
3522
- void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
3523
- if (cur->type == repack_type) {
3524
- memcpy(cur->data, data, data_size);
3525
- return;
3742
+ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
3743
+ block_iq4_nlx4 out;
3744
+
3745
+ for (int i = 0; i < 4; i++) {
3746
+ out.d[i] = in[i].d;
3526
3747
  }
3527
3748
 
3528
- GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);
3749
+ const int end = QK4_NL * 2 / blck_size_interleave;
3529
3750
 
3530
- switch (repack_type) {
3531
- case GGML_TYPE_Q4_0_8_8:
3532
- repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
3533
- break;
3534
- case GGML_TYPE_Q4_0_4_8:
3535
- repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
3536
- break;
3537
- case GGML_TYPE_Q4_0_4_4:
3538
- repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
3751
+ // TODO: this branch seems wrong
3752
+ //if (blck_size_interleave == 8) {
3753
+ // for (int i = 0; i < end; ++i) {
3754
+ // int src_id = i % 4;
3755
+ // int src_offset = (i / 4) * blck_size_interleave;
3756
+ // int dst_offset = i * blck_size_interleave;
3757
+
3758
+ // // Using memcpy to avoid unaligned memory accesses
3759
+ // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
3760
+ // }
3761
+ //} else
3762
+ if (blck_size_interleave == 4) {
3763
+ for (int i = 0; i < end; ++i) {
3764
+ int src_id = i % 4;
3765
+ int src_offset = (i / 4) * blck_size_interleave;
3766
+ int dst_offset = i * blck_size_interleave;
3767
+
3768
+ memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
3769
+ }
3770
+ } else {
3771
+ GGML_ASSERT(false);
3772
+ }
3773
+
3774
+ return out;
3775
+ }
3776
+
3777
+ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
3778
+ GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
3779
+ //GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
3780
+ GGML_ASSERT(interleave_block == 4);
3781
+
3782
+ block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
3783
+ const block_iq4_nl * src = (const block_iq4_nl *)data;
3784
+ block_iq4_nl dst_tmp[4];
3785
+ int nrow = ggml_nrows(t);
3786
+ int nrows_interleaved = 4;
3787
+ int nblocks = t->ne[0] / QK4_0;
3788
+
3789
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
3790
+
3791
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3792
+ return -1;
3793
+ }
3794
+
3795
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
3796
+ for (int64_t x = 0; x < nblocks; x++) {
3797
+ for (int i = 0; i < nrows_interleaved; i++) {
3798
+ dst_tmp[i] = src[x + i * nblocks];
3799
+ }
3800
+ *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
3801
+ }
3802
+ src += nrows_interleaved * nblocks;
3803
+ }
3804
+ return 0;
3805
+
3806
+ GGML_UNUSED(data_size);
3807
+ }
3808
+
3809
+ namespace ggml::cpu::aarch64 {
3810
+ // repack
3811
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
3812
+ int repack(struct ggml_tensor *, const void *, size_t);
3813
+
3814
+ // TODO: generalise.
3815
+ template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
3816
+ return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
3817
+ }
3818
+
3819
+ template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
3820
+ return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
3821
+ }
3822
+
3823
+ template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
3824
+ return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
3825
+ }
3826
+
3827
+ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
3828
+ return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
3829
+ }
3830
+
3831
+ // TODO: needs to be revisited
3832
+ //template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
3833
+ // return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
3834
+ //}
3835
+
3836
+ // gemv
3837
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
3838
+ void gemv(int, float *, size_t, const void *, const void *, int, int);
3839
+
3840
+ template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3841
+ ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
3842
+ }
3843
+
3844
+ template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3845
+ ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
3846
+ }
3847
+
3848
+ template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3849
+ ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
3850
+ }
3851
+
3852
+ template <>
3853
+ void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3854
+ ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
3855
+ }
3856
+
3857
+ // gemm
3858
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
3859
+ void gemm(int, float *, size_t, const void *, const void *, int, int);
3860
+
3861
+ template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3862
+ ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
3863
+ }
3864
+
3865
+ template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3866
+ ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
3867
+ }
3868
+
3869
+ template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3870
+ ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
3871
+ }
3872
+
3873
+ template <>
3874
+ void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3875
+ ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
3876
+ }
3877
+
3878
+ class tensor_traits_base : public ggml::cpu::tensor_traits {
3879
+ public:
3880
+ virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
3881
+ };
3882
+
3883
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
3884
+
3885
+ bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
3886
+ // not really a GGML_TYPE_Q8_0 but same size.
3887
+ switch (op->op) {
3888
+ case GGML_OP_MUL_MAT:
3889
+ size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
3890
+ return true;
3891
+ case GGML_OP_MUL_MAT_ID:
3892
+ size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
3893
+ size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
3894
+ size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
3895
+ return true;
3896
+ default:
3897
+ // GGML_ABORT("fatal error");
3539
3898
  break;
3899
+ }
3900
+ return false;
3901
+ }
3902
+
3903
+ bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
3904
+ switch (op->op) {
3905
+ case GGML_OP_MUL_MAT:
3906
+ forward_mul_mat(params, op);
3907
+ return true;
3908
+ case GGML_OP_MUL_MAT_ID:
3909
+ forward_mul_mat_id(params, op);
3910
+ return true;
3540
3911
  default:
3541
- GGML_ABORT("Unsupported type");
3912
+ // GGML_ABORT("fatal error");
3913
+ break;
3914
+ }
3915
+ return false;
3542
3916
  }
3543
- }
3544
3917
 
3545
- enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
3918
+ void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
3919
+ const ggml_tensor * src0 = op->src[0];
3920
+ const ggml_tensor * src1 = op->src[1];
3921
+ ggml_tensor * dst = op;
3922
+
3923
+ GGML_TENSOR_BINARY_OP_LOCALS
3924
+
3925
+ const int ith = params->ith;
3926
+ const int nth = params->nth;
3927
+
3928
+ GGML_ASSERT(ne0 == ne01);
3929
+ GGML_ASSERT(ne1 == ne11);
3930
+ GGML_ASSERT(ne2 == ne12);
3931
+ GGML_ASSERT(ne3 == ne13);
3932
+
3933
+ // dst cannot be transposed or permuted
3934
+ GGML_ASSERT(nb0 == sizeof(float));
3935
+ GGML_ASSERT(nb0 <= nb1);
3936
+ GGML_ASSERT(nb1 <= nb2);
3937
+ GGML_ASSERT(nb2 <= nb3);
3938
+
3939
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
3940
+
3941
+ GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
3942
+ // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
3943
+
3944
+ char * wdata = static_cast<char *>(params->wdata);
3945
+ const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
3946
+
3947
+ assert(params->wsize >= nbw1 * ne11);
3948
+
3949
+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
3950
+
3951
+ int64_t i11_processed = 0;
3952
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
3953
+ quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
3954
+ INTER_SIZE);
3955
+ }
3956
+ i11_processed = ne11 - ne11 % 4;
3957
+ for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
3958
+ from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
3959
+ }
3960
+
3961
+ ggml_barrier(params->threadpool);
3962
+
3963
+ const void * src1_wdata = params->wdata;
3964
+ const size_t src1_col_stride = ggml_row_size(GGML_TYPE_Q8_0, ne10);
3965
+ int64_t src0_start = (ith * ne01) / nth;
3966
+ int64_t src0_end = ((ith + 1) * ne01) / nth;
3967
+ src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
3968
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
3969
+ if (src0_start >= src0_end) {
3970
+ return;
3971
+ }
3972
+
3973
+ // If there are more than three rows in src1, use gemm; otherwise, use gemv.
3974
+ if (ne11 > 3) {
3975
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
3976
+ (const char *) src0->data + src0_start * nb01,
3977
+ (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
3978
+ }
3979
+ for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
3980
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
3981
+ (const char *) src0->data + src0_start * nb01,
3982
+ (const char *) src1_wdata + (src1_col_stride * iter), 1,
3983
+ src0_end - src0_start);
3984
+ }
3985
+ }
3986
+
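forward_mul_mat above splits the src0 rows across threads and rounds each slice up to a multiple of NB_COLS, so a packed group of interleaved columns is never split between threads; src1 rows are then served by gemm in groups of four and by gemv for the remainder. A small sketch of the rounding rule, for illustration only:

#include <cstdint>

// Sketch (illustration only): the slice-alignment rule used above. A thread
// owns src0 rows [start, end), with both bounds rounded up to a multiple of
// NB_COLS so an interleaved column group stays on one thread.
static int64_t round_up_to_cols(int64_t v, int64_t nb_cols) {
    return (v % nb_cols) ? v + nb_cols - (v % nb_cols) : v;
}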
3987
+ void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
3988
+ const ggml_tensor * src0 = op->src[0];
3989
+ const ggml_tensor * src1 = op->src[1];
3990
+ const ggml_tensor * ids = op->src[2];
3991
+ ggml_tensor * dst = op;
3992
+
3993
+ GGML_TENSOR_BINARY_OP_LOCALS
3994
+
3995
+ const int ith = params->ith;
3996
+ const int nth = params->nth;
3997
+
3998
+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
3999
+
4000
+ // we don't support permuted src0 or src1
4001
+ GGML_ASSERT(nb00 == ggml_type_size(src0->type));
4002
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
4003
+
4004
+ // dst cannot be transposed or permuted
4005
+ GGML_ASSERT(nb0 == sizeof(float));
4006
+ GGML_ASSERT(nb0 <= nb1);
4007
+ GGML_ASSERT(nb1 <= nb2);
4008
+ GGML_ASSERT(nb2 <= nb3);
4009
+
4010
+ GGML_ASSERT(ne03 == 1);
4011
+ GGML_ASSERT(ne13 == 1);
4012
+ GGML_ASSERT(ne3 == 1);
4013
+
4014
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
4015
+
4016
+ // row groups
4017
+ const int n_ids = ids->ne[0]; // n_expert_used
4018
+ const int n_as = ne02; // n_expert
4019
+
4020
+ const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
4021
+ const size_t nbw2 = nbw1*ne11;
4022
+ const size_t nbw3 = nbw2*ne12;
4023
+
4024
+ struct mmid_row_mapping {
4025
+ int32_t i1;
4026
+ int32_t i2;
4027
+ };
4028
+
4029
+ GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
4030
+ n_as * ne12 * sizeof(mmid_row_mapping)));
4031
+
4032
+ auto wdata = (char *) params->wdata;
4033
+ auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
4034
+ int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
4035
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
4036
+
4037
+ // src1: float32 => block_q8_0
4038
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
4039
+ for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
4040
+ from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
4041
+ (void *) (wdata + i12 * nbw2 + i11 * nbw1),
4042
+ ne10);
4043
+ }
4044
+ }
4045
+
4046
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
4047
+
4048
+ if (ith == 0) {
4049
+ // initialize matrix_row_counts
4050
+ memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
4051
+
4052
+ // group rows by src0 matrix
4053
+ for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
4054
+ for (int32_t id = 0; id < n_ids; ++id) {
4055
+ const int32_t i02 =
4056
+ *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
4057
+
4058
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
4059
+
4060
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
4061
+ matrix_row_counts[i02] += 1;
4062
+ }
4063
+ }
4064
+ }
4065
+
4066
+ ggml_barrier(params->threadpool);
4067
+
4068
+ // compute each matrix multiplication in sequence
4069
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
4070
+ const int64_t cne1 = matrix_row_counts[cur_a];
4071
+
4072
+ if (cne1 == 0) {
4073
+ continue;
4074
+ }
4075
+
4076
+ auto src0_cur = (const char *) src0->data + cur_a*nb02;
4077
+
4078
+ //const int64_t nr0 = ne01; // src0 rows
4079
+ const int64_t nr1 = cne1; // src1 rows
4080
+
4081
+ int64_t src0_cur_start = (ith * ne01) / nth;
4082
+ int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
4083
+ src0_cur_start =
4084
+ (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
4085
+ src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
4086
+
4087
+ if (src0_cur_start >= src0_cur_end) return;
4088
+
4089
+ for (int ir1 = 0; ir1 < nr1; ir1++) {
4090
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
4091
+ const int id = row_mapping.i1; // selected expert index
4092
+
4093
+ const int64_t i11 = id % ne11;
4094
+ const int64_t i12 = row_mapping.i2; // row index in src1
4095
+
4096
+ const int64_t i1 = id; // selected expert index
4097
+ const int64_t i2 = i12; // row
4098
+
4099
+ auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
4100
+
4101
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
4102
+ ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
4103
+ ne01, src0_cur + src0_cur_start * nb01,
4104
+ src1_col, 1, src0_cur_end - src0_cur_start);
4105
+ }
4106
+ }
4107
+ #undef MMID_MATRIX_ROW
4108
+ }
4109
+
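forward_mul_mat_id carves its scratch buffer into three regions: the quantized src1 rows, one row counter per expert, and the mappings that group src1 rows by the expert that consumes them. An equivalent view of that layout as a plain struct, sizes being runtime values (sketch only, not from the diff):

#include <cstdint>

// Sketch (illustration only): the scratch regions laid out by
// forward_mul_mat_id above; the actual code computes raw offsets into
// params->wdata in this same order.
struct mmid_scratch_view {
    char *    src1_q8;           // GGML_PAD(nbw3, sizeof(int64_t)) bytes of Q8_0 rows
    int64_t * matrix_row_counts; // n_as counters, one per expert
    struct row_mapping { int32_t i1, i2; } * matrix_rows; // n_as * ne12 mappings grouping src1 rows by expert
};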
4110
+ int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
4111
+ GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
4112
+ (int) NB_COLS, (int) INTER_SIZE);
4113
+ return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
4114
+ }
4115
+ };
4116
+
4117
+ // instance for Q4
4118
+ static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
4119
+ static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
4120
+ static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
4121
+
4122
+ // instance for IQ4
4123
+ static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
4124
+
4125
+ } // namespace ggml::cpu::aarch64
4126
+
4127
+ static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
3546
4128
  if (cur->type == GGML_TYPE_Q4_0) {
3547
- // TODO: enable for AVX2 - currently disabled due to bad gemv performance
3548
- if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
3549
- return GGML_TYPE_Q4_0_8_8;
4129
+ if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
4130
+ if (cur->ne[1] % 8 == 0) {
4131
+ return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
4132
+ }
3550
4133
  }
3551
4134
  if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
3552
- return GGML_TYPE_Q4_0_4_8;
4135
+ if (cur->ne[1] % 4 == 0) {
4136
+ return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
4137
+ }
4138
+ }
4139
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
4140
+ if (cur->ne[1] % 4 == 0) {
4141
+ return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
4142
+ }
3553
4143
  }
3554
- if (ggml_cpu_has_neon()) {
3555
- return GGML_TYPE_Q4_0_4_4;
4144
+ } else if (cur->type == GGML_TYPE_IQ4_NL) {
4145
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
4146
+ if (cur->ne[1] % 4 == 0) {
4147
+ return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
4148
+ }
4149
+ }
4150
+ }
4151
+
4152
+ return nullptr;
4153
+ }
4154
+
4155
+ static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
4156
+ tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_aarch64_get_optimal_repack_type(tensor));
4157
+
4158
+ GGML_UNUSED(buffer);
4159
+ }
4160
+
4161
+ static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
4162
+ const void * data, size_t offset, size_t size) {
4163
+ GGML_ASSERT(offset == 0);
4164
+ GGML_ASSERT(size == ggml_nbytes(tensor));
4165
+
4166
+ auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
4167
+ auto OK = tensor_traits->repack(tensor, data, size);
4168
+
4169
+ GGML_ASSERT(OK == 0);
4170
+ GGML_UNUSED(buffer);
4171
+ }
4172
+
4173
+ static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
4174
+ return "CPU_AARCH64";
4175
+
4176
+ GGML_UNUSED(buft);
4177
+ }
4178
+
4179
+ static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
4180
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
4181
+
4182
+ if (buffer == nullptr) {
4183
+ return nullptr;
4184
+ }
4185
+
4186
+ buffer->buft = buft;
4187
+ buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
4188
+ buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
4189
+ return buffer;
4190
+ }
4191
+
4192
+ static size_t ggml_backend_cpu_aarch64_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
4193
+ return TENSOR_ALIGNMENT;
4194
+
4195
+ GGML_UNUSED(buft);
4196
+ }
4197
+
4198
+ namespace ggml::cpu::aarch64 {
4199
+ class extra_buffer_type : ggml::cpu::extra_buffer_type {
4200
+ bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
4201
+ if ( op->op == GGML_OP_MUL_MAT &&
4202
+ op->src[0]->buffer &&
4203
+ (ggml_n_dims(op->src[0]) == 2) &&
4204
+ op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() &&
4205
+ ggml_aarch64_get_optimal_repack_type(op->src[0])
4206
+ ) {
4207
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
4208
+ return false;
4209
+ }
4210
+ if (op->src[1]->type == GGML_TYPE_F32) {
4211
+ return true;
4212
+ }
4213
+ //if (op->src[1]->type == GGML_TYPE_Q8_0) {
4214
+ // return true;
4215
+ //}
4216
+ // may be possible if Q8_0 packed...
4217
+ } else if (op->op == GGML_OP_MUL_MAT_ID
4218
+ && op->src[0]->buffer
4219
+ && (ggml_n_dims(op->src[0]) == 3)
4220
+ && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()
4221
+ && ggml_aarch64_get_optimal_repack_type(op->src[0])
4222
+ ) {
4223
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
4224
+ return false;
4225
+ }
4226
+ if (op->src[1]->type == GGML_TYPE_F32) {
4227
+ return true;
4228
+ }
4229
+ //if (op->src[1]->type == GGML_TYPE_Q8_0) {
4230
+ // return true;
4231
+ //}
3556
4232
  }
4233
+ return false;
3557
4234
  }
3558
4235
 
3559
- return cur->type;
4236
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
4237
+ if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
4238
+ if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) {
4239
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
4240
+ }
4241
+ }
4242
+ return nullptr;
4243
+ }
4244
+ };
4245
+ } // namespace ggml::cpu::aarch64
4246
+
4247
+ ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
4248
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
4249
+ /* .iface = */ {
4250
+ /* .get_name = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
4251
+ /* .alloc_buffer = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
4252
+ /* .get_alignment = */ ggml_backend_cpu_aarch64_buffer_type_get_alignment,
4253
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
4254
+ /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
4255
+ /* .is_host = */ nullptr,
4256
+ },
4257
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
4258
+ /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
4259
+ };
4260
+
4261
+ return &ggml_backend_cpu_buffer_type_aarch64;
3560
4262
  }
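The buffer type above is what exposes the repacking path to the rest of ggml: init_tensor records the matching tensor_traits in tensor->extra, and set_tensor performs the repack as weights are uploaded. A usage sketch, assuming ggml-cpu.h declares ggml_backend_cpu_aarch64_buffer_type() and ggml-backend.h provides ggml_backend_buft_name():

#include "ggml-backend.h"
#include "ggml-cpu.h"   // assumed to declare ggml_backend_cpu_aarch64_buffer_type()
#include <cstdio>

// Sketch (illustration only): querying the repacking buffer type directly.
// Normally it is selected automatically through the extra-buffer-type
// mechanism when weights are loaded onto the CPU backend.
int main() {
    ggml_backend_buffer_type_t buft = ggml_backend_cpu_aarch64_buffer_type();
    std::printf("%s\n", ggml_backend_buft_name(buft)); // prints "CPU_AARCH64"
    return 0;
}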