cui-llama.rn 1.7.4 → 1.7.6

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (276)
  1. package/README.md +217 -17
  2. package/android/src/main/CMakeLists.txt +34 -15
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +79 -5
  4. package/android/src/main/java/com/rnllama/RNLlama.java +237 -0
  5. package/android/src/main/jni.cpp +213 -14
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
  16. package/cpp/README.md +1 -1
  17. package/cpp/chat-parser.cpp +385 -0
  18. package/cpp/chat-parser.h +120 -0
  19. package/cpp/chat.cpp +726 -596
  20. package/cpp/chat.h +71 -6
  21. package/cpp/common.cpp +56 -38
  22. package/cpp/common.h +9 -3
  23. package/cpp/ggml-backend-reg.cpp +5 -0
  24. package/cpp/ggml-backend.cpp +10 -2
  25. package/cpp/ggml-common.h +4 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +1 -1
  27. package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
  28. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  29. package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
  30. package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
  31. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  32. package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
  33. package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  34. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  35. package/cpp/ggml-cpu/common.h +4 -3
  36. package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
  37. package/cpp/ggml-cpu/ggml-cpu.c +123 -104
  38. package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
  39. package/cpp/ggml-cpu/ops.cpp +330 -148
  40. package/cpp/ggml-cpu/ops.h +1 -0
  41. package/cpp/ggml-cpu/quants.c +1158 -0
  42. package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  43. package/cpp/ggml-cpu/repack.cpp +1571 -0
  44. package/cpp/ggml-cpu/repack.h +98 -0
  45. package/cpp/ggml-cpu/simd-mappings.h +330 -38
  46. package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  47. package/cpp/ggml-cpu/vec.cpp +87 -18
  48. package/cpp/ggml-cpu/vec.h +249 -94
  49. package/cpp/ggml-cpu.h +1 -0
  50. package/cpp/ggml-impl.h +63 -183
  51. package/cpp/ggml-llama-sim.metallib +0 -0
  52. package/cpp/ggml-llama.metallib +0 -0
  53. package/cpp/ggml-metal.m +152 -45
  54. package/cpp/ggml-quants.c +0 -2
  55. package/cpp/ggml.c +61 -21
  56. package/cpp/ggml.h +22 -3
  57. package/cpp/gguf.cpp +24 -3
  58. package/cpp/json-partial.cpp +256 -0
  59. package/cpp/json-partial.h +38 -0
  60. package/cpp/json-schema-to-grammar.cpp +5 -47
  61. package/cpp/json-schema-to-grammar.h +4 -4
  62. package/cpp/llama-arch.cpp +153 -3
  63. package/cpp/llama-arch.h +27 -1
  64. package/cpp/llama-batch.cpp +741 -272
  65. package/cpp/llama-batch.h +112 -54
  66. package/cpp/llama-chat.cpp +30 -8
  67. package/cpp/llama-chat.h +1 -0
  68. package/cpp/llama-context.cpp +524 -339
  69. package/cpp/llama-context.h +38 -17
  70. package/cpp/llama-cparams.cpp +4 -0
  71. package/cpp/llama-cparams.h +2 -0
  72. package/cpp/llama-grammar.cpp +12 -2
  73. package/cpp/llama-graph.cpp +431 -356
  74. package/cpp/llama-graph.h +126 -58
  75. package/cpp/llama-hparams.cpp +10 -2
  76. package/cpp/llama-hparams.h +19 -2
  77. package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
  78. package/cpp/llama-kv-cache-unified-iswa.h +128 -0
  79. package/cpp/llama-kv-cache-unified.cpp +1841 -0
  80. package/cpp/llama-kv-cache-unified.h +303 -0
  81. package/cpp/llama-kv-cells.h +439 -0
  82. package/cpp/llama-memory-hybrid.cpp +246 -0
  83. package/cpp/llama-memory-hybrid.h +138 -0
  84. package/cpp/llama-memory-recurrent.cpp +1112 -0
  85. package/cpp/llama-memory-recurrent.h +183 -0
  86. package/cpp/llama-memory.cpp +41 -0
  87. package/cpp/llama-memory.h +86 -5
  88. package/cpp/llama-mmap.cpp +1 -1
  89. package/cpp/llama-model-loader.cpp +42 -17
  90. package/cpp/llama-model-saver.cpp +1 -0
  91. package/cpp/llama-model.cpp +1639 -513
  92. package/cpp/llama-model.h +26 -0
  93. package/cpp/llama-sampling.cpp +2 -2
  94. package/cpp/llama-vocab.cpp +65 -28
  95. package/cpp/llama-vocab.h +1 -0
  96. package/cpp/llama.cpp +11 -7
  97. package/cpp/llama.h +150 -42
  98. package/cpp/minja/chat-template.hpp +1 -1
  99. package/cpp/minja/minja.hpp +1 -1
  100. package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
  101. package/cpp/nlohmann/json_fwd.hpp +187 -0
  102. package/cpp/regex-partial.cpp +204 -0
  103. package/cpp/regex-partial.h +56 -0
  104. package/cpp/rn-llama.cpp +646 -35
  105. package/cpp/rn-llama.h +32 -1
  106. package/cpp/rn-tts.h +39 -0
  107. package/cpp/sampling.cpp +7 -8
  108. package/cpp/tools/mtmd/clip-impl.h +5 -0
  109. package/cpp/tools/mtmd/clip.cpp +572 -436
  110. package/cpp/tools/mtmd/clip.h +14 -4
  111. package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
  112. package/cpp/tools/mtmd/mtmd-audio.h +2 -17
  113. package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
  114. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  115. package/cpp/tools/mtmd/mtmd.cpp +368 -248
  116. package/cpp/tools/mtmd/mtmd.h +6 -70
  117. package/cpp/unicode.cpp +5 -0
  118. package/ios/CMakeLists.txt +26 -6
  119. package/ios/RNLlama.h +1 -1
  120. package/ios/RNLlama.mm +153 -3
  121. package/ios/RNLlamaContext.h +9 -1
  122. package/ios/RNLlamaContext.mm +112 -9
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  143. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  144. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  184. package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  218. package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  225. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  226. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  227. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  228. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  229. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  230. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  231. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  232. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  233. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  234. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  235. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  259. package/jest/mock.js +24 -0
  260. package/package.json +1 -1
  261. package/src/NativeRNLlama.ts +46 -2
  262. package/src/index.ts +105 -1
  263. package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  264. package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
  265. package/cpp/ggml-cpu/sgemm.cpp +0 -3544
  266. package/cpp/ggml-cpu/sgemm.h +0 -14
  267. package/cpp/llama-kv-cache.cpp +0 -2827
  268. package/cpp/llama-kv-cache.h +0 -515
  269. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  270. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  274. /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  275. /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
  276. /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
package/cpp/ggml-cpu/arch/arm/quants.c (new file)
@@ -0,0 +1,4114 @@
1
+ #define LM_GGML_COMMON_IMPL_C
2
+ #include "ggml-common.h"
3
+ #include "ggml-quants.h"
4
+ #include "ggml-impl.h"
5
+ #include "ggml-cpu.h"
6
+ #include "simd-mappings.h"
7
+
8
+ #include "../../quants.h"
9
+ #include "../../ggml-cpu-impl.h"
10
+
11
+ #include <math.h>
12
+ #include <string.h>
13
+ #include <assert.h>
14
+ #include <float.h>
15
+ #include <stdlib.h> // for qsort
16
+ #include <stdio.h> // for LM_GGML_ASSERT
17
+
18
+ #define GROUP_MAX_EPS 1e-15f
19
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
20
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
21
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
22
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
23
+
24
+ #define UNUSED LM_GGML_UNUSED
25
+
26
+ #if defined(__ARM_NEON)
27
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
28
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
29
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
30
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
31
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
32
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
33
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
34
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
35
+
36
+ // precomputed tables for expanding 8 bits to 8 bytes:
37
+ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
38
+ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
39
+ #endif
40
+
41
+ void quantize_row_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
42
+ assert(QK8_0 == 32);
43
+ assert(k % QK8_0 == 0);
44
+ const int nb = k / QK8_0;
45
+
46
+ block_q8_0 * LM_GGML_RESTRICT y = vy;
47
+
48
+ #if defined(__ARM_NEON)
49
+ for (int i = 0; i < nb; i++) {
50
+ float32x4_t srcv [8];
51
+ float32x4_t asrcv[8];
52
+ float32x4_t amaxv[8];
53
+
54
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
55
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
56
+
57
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
58
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
59
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
60
+
61
+ const float amax = vmaxvq_f32(amaxv[0]);
62
+
63
+ const float d = amax / ((1 << 7) - 1);
64
+ const float id = d ? 1.0f/d : 0.0f;
65
+
66
+ y[i].d = LM_GGML_CPU_FP32_TO_FP16(d);
67
+
68
+ for (int j = 0; j < 8; j++) {
69
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
70
+ const int32x4_t vi = vcvtnq_s32_f32(v);
71
+
72
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
73
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
74
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
75
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
76
+ }
77
+ }
78
+ #else
79
+ LM_GGML_UNUSED(nb);
80
+ // scalar
81
+ quantize_row_q8_0_ref(x, y, k);
82
+ #endif
83
+ }
84
+
85
+ void quantize_row_q8_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
86
+ assert(k % QK8_1 == 0);
87
+ const int nb = k / QK8_1;
88
+
89
+ block_q8_1 * LM_GGML_RESTRICT y = vy;
90
+ #if defined(__ARM_NEON)
91
+ for (int i = 0; i < nb; i++) {
92
+ float32x4_t srcv [8];
93
+ float32x4_t asrcv[8];
94
+ float32x4_t amaxv[8];
95
+
96
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
97
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
98
+
99
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
100
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
101
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
102
+
103
+ const float amax = vmaxvq_f32(amaxv[0]);
104
+
105
+ const float d = amax / ((1 << 7) - 1);
106
+ const float id = d ? 1.0f/d : 0.0f;
107
+
108
+ y[i].d = LM_GGML_CPU_FP32_TO_FP16(d);
109
+
110
+ int32x4_t accv = vdupq_n_s32(0);
111
+
112
+ for (int j = 0; j < 8; j++) {
113
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
114
+ const int32x4_t vi = vcvtnq_s32_f32(v);
115
+
116
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
117
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
118
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
119
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
120
+
121
+ accv = vaddq_s32(accv, vi);
122
+ }
123
+
124
+ y[i].s = LM_GGML_CPU_FP32_TO_FP16(d * vaddvq_s32(accv));
125
+ }
126
+ #else
127
+ LM_GGML_UNUSED(nb);
128
+ // scalar
129
+ quantize_row_q8_1_ref(x, y, k);
130
+ #endif
131
+ }
132
+
133
+ // placeholder implementation for Apple targets
134
+ void quantize_row_q8_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k) {
135
+ quantize_row_q8_K_ref(x, y, k);
136
+ }
137
+
138
+ //===================================== Dot products =================================
139
+
140
+ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
141
+ const int qk = QK8_0;
142
+ const int nb = n / qk;
143
+
144
+ assert(n % qk == 0);
145
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
146
+ assert((nrc == 2) || (nrc == 1));
147
+ #else
148
+ assert(nrc == 1);
149
+ #endif
150
+ UNUSED(nrc);
151
+ UNUSED(bx);
152
+ UNUSED(by);
153
+ UNUSED(bs);
154
+
155
+ const block_q4_0 * LM_GGML_RESTRICT x = vx;
156
+ const block_q8_0 * LM_GGML_RESTRICT y = vy;
157
+
158
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
159
+ if (nrc == 2) {
160
+ const block_q4_0 * LM_GGML_RESTRICT vx0 = vx;
161
+ const block_q4_0 * LM_GGML_RESTRICT vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
162
+ const block_q8_0 * LM_GGML_RESTRICT vy0 = vy;
163
+ const block_q8_0 * LM_GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
164
+
165
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
166
+
167
+ for (int i = 0; i < nb; i++) {
168
+ const block_q4_0 * LM_GGML_RESTRICT b_x0 = &vx0[i];
169
+ const block_q4_0 * LM_GGML_RESTRICT b_x1 = &vx1[i];
170
+ const block_q8_0 * LM_GGML_RESTRICT b_y0 = &vy0[i];
171
+ const block_q8_0 * LM_GGML_RESTRICT b_y1 = &vy1[i];
172
+
173
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
174
+ const int8x16_t s8b = vdupq_n_s8(0x8);
175
+
176
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
177
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
178
+
179
+ // 4-bit -> 8-bit
180
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
181
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
182
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
183
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
184
+
185
+ // sub 8
186
+ const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
187
+ const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
188
+ const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
189
+ const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
190
+
191
+ // load y
192
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
193
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
194
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
195
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
196
+
197
+ float32_t _scale[4] = {
198
+ LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
199
+ LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d),
200
+ LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
201
+ LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d)
202
+ };
203
+ float32x4_t scale = vld1q_f32(_scale);
204
+
205
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
206
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
207
+
208
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
209
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
210
+
211
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
212
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
213
+
214
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
215
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
216
+
217
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
218
+ l1, r1)), l2, r2)), l3, r3))), scale);
219
+ }
220
+
221
+ float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
222
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
223
+
224
+ vst1_f32(s, vget_low_f32 (sumv2));
225
+ vst1_f32(s + bs, vget_high_f32(sumv2));
226
+
227
+ return;
228
+ }
229
+ #endif
230
+
231
+ int ib = 0;
232
+ float sumf = 0;
233
+
234
+ #if defined(__ARM_FEATURE_SVE)
235
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
236
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
237
+
238
+ const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
239
+
240
+ // VLA implementation using a switch on the runtime vector length
241
+ switch (vector_length) {
242
+ case 128:
243
+ {
244
+ // predicate for activating higher lanes for 4 float32 elements
245
+ const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
246
+
247
+ for (; ib + 1 < nb; ib += 2) {
248
+ const block_q4_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
249
+ const block_q4_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
250
+ const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
251
+ const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
252
+
253
+ // load x
254
+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
255
+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
256
+
257
+ // 4-bit -> 8-bit
258
+ const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
259
+ const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
260
+ const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
261
+ const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
262
+
263
+ // sub 8
264
+ const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
265
+ const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
266
+ const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
267
+ const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
268
+
269
+ // load y
270
+ const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
271
+ const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
272
+ const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
273
+ const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
274
+
275
+ // dot product
276
+ sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
277
+ svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
278
+ svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
279
+ sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
280
+ svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
281
+ svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
282
+ }
283
+
284
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
285
+ } break;
286
+ case 256:
287
+ {
288
+ // predicate for activating higher lanes for 16 int8 elements
289
+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
290
+ // predicate for activating lower lanes for 16 int8 elements
291
+ const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
292
+
293
+ for (; ib + 1 < nb; ib += 2) {
294
+ const block_q4_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
295
+ const block_q4_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
296
+ const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
297
+ const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
298
+
299
+ // load x
300
+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
301
+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
302
+
303
+ // 4-bit -> 8-bit
304
+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
305
+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
306
+
307
+ // sub 8
308
+ const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
309
+ const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
310
+
311
+ // load y
312
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
313
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
314
+
315
+ // dot product
316
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
317
+ svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
318
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
319
+ svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
320
+ }
321
+
322
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
323
+ } break;
324
+ case 512:
325
+ {
326
+ // predicate for activating higher lanes for 32 int8 elements
327
+ const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
328
+
329
+ // predicate for activating higher lanes for 16 int8 elements
330
+ const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
331
+ // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
332
+ const svbool_t pl16 = svnot_b_z(ph32, ph16);
333
+
334
+ for (; ib + 1 < nb; ib += 2) {
335
+ const block_q4_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
336
+ const block_q4_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
337
+ const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
338
+ const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
339
+
340
+ // load x
341
+ const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
342
+ const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
343
+
344
+ // 4-bit -> 8-bit
345
+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
346
+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
347
+
348
+ // sub 8
349
+ const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
350
+ const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
351
+
352
+ // load y
353
+ const svint8_t qy0 = svld1_s8(ph32, y0->qs);
354
+ const svint8_t qy1 = svld1_s8(ph32, y1->qs);
355
+
356
+ // dot product
357
+ sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
358
+ svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
359
+ sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
360
+ svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
361
+ }
362
+
363
+ sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
364
+ } break;
365
+ default:
366
+ assert(false && "Unsupported vector length");
367
+ break;
368
+ }
369
+
370
+ #elif defined(__ARM_NEON)
371
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
372
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
373
+
374
+ for (; ib + 1 < nb; ib += 2) {
375
+ const block_q4_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
376
+ const block_q4_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
377
+ const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
378
+ const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
379
+
380
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
381
+ const int8x16_t s8b = vdupq_n_s8(0x8);
382
+
383
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
384
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
385
+
386
+ // 4-bit -> 8-bit
387
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
388
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
389
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
390
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
391
+
392
+ // sub 8
393
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
394
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
395
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
396
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
397
+
398
+ // load y
399
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
400
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
401
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
402
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
403
+
404
+ // dot product into int32x4_t
405
+ const int32x4_t p_0 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
406
+ const int32x4_t p_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
407
+
408
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
409
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
410
+ }
411
+
412
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
413
+ #endif
414
+ for (; ib < nb; ++ib) {
415
+ int sumi0 = 0;
416
+ int sumi1 = 0;
417
+
418
+ for (int j = 0; j < qk/2; ++j) {
419
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
420
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
421
+
422
+ sumi0 += (v0 * y[ib].qs[j]);
423
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
424
+ }
425
+
426
+ int sumi = sumi0 + sumi1;
427
+ sumf += sumi*LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d);
428
+ }
429
+
430
+ *s = sumf;
431
+ }
432
+
433
+ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
434
+ const int qk = QK8_1;
435
+ const int nb = n / qk;
436
+
437
+ assert(n % qk == 0);
438
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
439
+ assert((nrc == 2) || (nrc == 1));
440
+ #else
441
+ assert(nrc == 1);
442
+ #endif
443
+ UNUSED(nrc);
444
+ UNUSED(bx);
445
+ UNUSED(by);
446
+ UNUSED(bs);
447
+
448
+ const block_q4_1 * LM_GGML_RESTRICT x = vx;
449
+ const block_q8_1 * LM_GGML_RESTRICT y = vy;
450
+
451
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
452
+ if (nrc == 2) {
453
+ const block_q4_1 * LM_GGML_RESTRICT vx0 = vx;
454
+ const block_q4_1 * LM_GGML_RESTRICT vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
455
+ const block_q8_1 * LM_GGML_RESTRICT vy0 = vy;
456
+ const block_q8_1 * LM_GGML_RESTRICT vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
457
+
458
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
459
+ float32x4_t summs0 = vdupq_n_f32(0.0f);
460
+
461
+ for (int i = 0; i < nb; i++) {
462
+ const block_q4_1 * LM_GGML_RESTRICT b_x0 = &vx0[i];
463
+ const block_q4_1 * LM_GGML_RESTRICT b_x1 = &vx1[i];
464
+ const block_q8_1 * LM_GGML_RESTRICT b_y0 = &vy0[i];
465
+ const block_q8_1 * LM_GGML_RESTRICT b_y1 = &vy1[i];
466
+
467
+ float32_t summs_t[4] = {
468
+ LM_GGML_CPU_FP16_TO_FP32(b_x0->m) * LM_GGML_CPU_FP16_TO_FP32(b_y0->s),
469
+ LM_GGML_CPU_FP16_TO_FP32(b_x1->m) * LM_GGML_CPU_FP16_TO_FP32(b_y0->s),
470
+ LM_GGML_CPU_FP16_TO_FP32(b_x0->m) * LM_GGML_CPU_FP16_TO_FP32(b_y1->s),
471
+ LM_GGML_CPU_FP16_TO_FP32(b_x1->m) * LM_GGML_CPU_FP16_TO_FP32(b_y1->s)
472
+ };
473
+ summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
474
+
475
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
476
+
477
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
478
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
479
+
480
+ // 4-bit -> 8-bit
481
+ const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
482
+ const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
483
+ const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
484
+ const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
485
+
486
+ // load y
487
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
488
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
489
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
490
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
491
+
492
+ // mmla into int32x4_t
493
+ float32_t _scale[4] = {
494
+ LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
495
+ LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d),
496
+ LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
497
+ LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d)
498
+ };
499
+ float32x4_t scale = vld1q_f32(_scale);
500
+
501
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
502
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
503
+
504
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
505
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
506
+
507
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
508
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
509
+
510
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
511
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
512
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
513
+ l1, r1)), l2, r2)), l3, r3))), scale);
514
+ }
515
+
516
+ float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
517
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
518
+
519
+ sumv2 = vaddq_f32(sumv2, summs0);
520
+
521
+ vst1_f32(s, vget_low_f32 (sumv2));
522
+ vst1_f32(s + bs, vget_high_f32(sumv2));
523
+
524
+ return;
525
+ }
526
+ #endif
527
+
528
+ int ib = 0;
529
+ float sumf = 0;
530
+
531
+ #if defined(__ARM_NEON)
532
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
533
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
534
+
535
+ float summs = 0;
536
+
537
+ for (; ib + 1 < nb; ib += 2) {
538
+ const block_q4_1 * LM_GGML_RESTRICT x0 = &x[ib + 0];
539
+ const block_q4_1 * LM_GGML_RESTRICT x1 = &x[ib + 1];
540
+ const block_q8_1 * LM_GGML_RESTRICT y0 = &y[ib + 0];
541
+ const block_q8_1 * LM_GGML_RESTRICT y1 = &y[ib + 1];
542
+
543
+ summs += LM_GGML_CPU_FP16_TO_FP32(x0->m) * LM_GGML_CPU_FP16_TO_FP32(y0->s) + LM_GGML_CPU_FP16_TO_FP32(x1->m) * LM_GGML_CPU_FP16_TO_FP32(y1->s);
544
+
545
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
546
+
547
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
548
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
549
+
550
+ // 4-bit -> 8-bit
551
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
552
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
553
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
554
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
555
+
556
+ // load y
557
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
558
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
559
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
560
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
561
+
562
+ // dot product into int32x4_t
563
+ const int32x4_t p_0 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
564
+ const int32x4_t p_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
565
+
566
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
567
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
568
+ }
569
+
570
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
571
+
572
+ #endif
573
+ for (; ib < nb; ++ib) {
574
+ int sumi0 = 0;
575
+ int sumi1 = 0;
576
+
577
+ for (int j = 0; j < qk/2; ++j) {
578
+ const int v0 = (x[ib].qs[j] & 0x0F);
579
+ const int v1 = (x[ib].qs[j] >> 4);
580
+
581
+ sumi0 += (v0 * y[ib].qs[j]);
582
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
583
+ }
584
+
585
+ int sumi = sumi0 + sumi1;
586
+ sumf += (LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_CPU_FP16_TO_FP32(x[ib].m)*LM_GGML_CPU_FP16_TO_FP32(y[ib].s);
587
+ }
588
+
589
+ *s = sumf;
590
+ }
591
+
592
+ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
593
+ const int qk = QK8_0;
594
+ const int nb = n / qk;
595
+
596
+ int ib = 0;
597
+ float sumf = 0;
598
+
599
+ assert(n % qk == 0);
600
+ assert(qk == QK5_0);
601
+ assert(nrc == 1);
602
+ UNUSED(nrc);
603
+ UNUSED(bx);
604
+ UNUSED(by);
605
+ UNUSED(bs);
606
+
607
+ const block_q5_0 * LM_GGML_RESTRICT x = vx;
608
+ const block_q8_0 * LM_GGML_RESTRICT y = vy;
609
+
610
+ #if defined(__ARM_NEON)
611
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
612
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
613
+
614
+ uint32_t qh0;
615
+ uint32_t qh1;
616
+
617
+ uint64_t tmp0[4];
618
+ uint64_t tmp1[4];
619
+
620
+ for (; ib + 1 < nb; ib += 2) {
621
+ const block_q5_0 * LM_GGML_RESTRICT x0 = &x[ib];
622
+ const block_q5_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
623
+ const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib];
624
+ const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
625
+
626
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
627
+
628
+ // extract the 5th bit via lookup table ((!b) << 4)
629
+ memcpy(&qh0, x0->qh, sizeof(qh0));
630
+ memcpy(&qh1, x1->qh, sizeof(qh1));
631
+
632
+ tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
633
+ tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
634
+ tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
635
+ tmp0[3] = table_b2b_1[(qh0 >> 24) ];
636
+
637
+ tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
638
+ tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
639
+ tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
640
+ tmp1[3] = table_b2b_1[(qh1 >> 24) ];
641
+
642
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
643
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
644
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
645
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
646
+
647
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
648
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
649
+
650
+ // 4-bit -> 8-bit
651
+ int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
652
+ int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
653
+ int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
654
+ int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
655
+
656
+ // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
657
+ const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
658
+ const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
659
+ const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
660
+ const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
661
+
662
+ // load y
663
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
664
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
665
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
666
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
667
+
668
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
669
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
670
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
671
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
672
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
673
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
674
+ }
675
+
676
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
677
+
678
+ #endif
679
+ for (; ib < nb; ++ib) {
680
+ uint32_t qh;
681
+ memcpy(&qh, x[ib].qh, sizeof(qh));
682
+
683
+ int sumi0 = 0;
684
+ int sumi1 = 0;
685
+
686
+ for (int j = 0; j < qk/2; ++j) {
687
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
688
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
689
+
690
+ const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
691
+ const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
692
+
693
+ sumi0 += (x0 * y[ib].qs[j]);
694
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
695
+ }
696
+
697
+ int sumi = sumi0 + sumi1;
698
+ sumf += (LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
699
+ }
700
+
701
+ *s = sumf;
702
+ }
703
+
704
+ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
705
+ const int qk = QK8_1;
706
+ const int nb = n / qk;
707
+
708
+ int ib = 0;
709
+ float sumf = 0;
710
+
711
+ assert(n % qk == 0);
712
+ assert(qk == QK5_1);
713
+ assert(nrc == 1);
714
+ UNUSED(nrc);
715
+ UNUSED(bx);
716
+ UNUSED(by);
717
+ UNUSED(bs);
718
+
719
+ const block_q5_1 * LM_GGML_RESTRICT x = vx;
720
+ const block_q8_1 * LM_GGML_RESTRICT y = vy;
721
+
722
+ #if defined(__ARM_NEON)
723
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
724
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
725
+
726
+ float summs0 = 0.0f;
727
+ float summs1 = 0.0f;
728
+
729
+ uint32_t qh0;
730
+ uint32_t qh1;
731
+
732
+ uint64_t tmp0[4];
733
+ uint64_t tmp1[4];
734
+
735
+ for (; ib + 1 < nb; ib += 2) {
736
+ const block_q5_1 * LM_GGML_RESTRICT x0 = &x[ib];
737
+ const block_q5_1 * LM_GGML_RESTRICT x1 = &x[ib + 1];
738
+ const block_q8_1 * LM_GGML_RESTRICT y0 = &y[ib];
739
+ const block_q8_1 * LM_GGML_RESTRICT y1 = &y[ib + 1];
740
+
741
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
742
+
743
+ summs0 += LM_GGML_CPU_FP16_TO_FP32(x0->m) * LM_GGML_CPU_FP16_TO_FP32(y0->s);
744
+ summs1 += LM_GGML_CPU_FP16_TO_FP32(x1->m) * LM_GGML_CPU_FP16_TO_FP32(y1->s);
745
+
746
+ // extract the 5th bit via lookup table ((b) << 4)
747
+ memcpy(&qh0, x0->qh, sizeof(qh0));
748
+ memcpy(&qh1, x1->qh, sizeof(qh1));
749
+
750
+ tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
751
+ tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
752
+ tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
753
+ tmp0[3] = table_b2b_0[(qh0 >> 24) ];
754
+
755
+ tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
756
+ tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
757
+ tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
758
+ tmp1[3] = table_b2b_0[(qh1 >> 24) ];
759
+
760
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
761
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
762
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
763
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
764
+
765
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
766
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
767
+
768
+ // 4-bit -> 8-bit
769
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
770
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
771
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
772
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
773
+
774
+ // add high bit
775
+ const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
776
+ const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
777
+ const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
778
+ const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
779
+
780
+ // load y
781
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
782
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
783
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
784
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
785
+
786
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
787
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
788
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
789
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
790
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
791
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
792
+ }
793
+
794
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
795
+
796
+ #endif
797
+ for (; ib < nb; ++ib) {
798
+ uint32_t qh;
799
+ memcpy(&qh, x[ib].qh, sizeof(qh));
800
+
801
+ int sumi0 = 0;
802
+ int sumi1 = 0;
803
+
804
+ for (int j = 0; j < qk/2; ++j) {
805
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
806
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
807
+
808
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
809
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
810
+
811
+ sumi0 += (x0 * y[ib].qs[j]);
812
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
813
+ }
814
+
815
+ int sumi = sumi0 + sumi1;
816
+ sumf += (LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_CPU_FP16_TO_FP32(x[ib].m)*LM_GGML_CPU_FP16_TO_FP32(y[ib].s);
817
+ }
818
+
819
+ *s = sumf;
820
+ }
821
+
822
+ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
823
+ const int qk = QK8_0;
824
+ const int nb = n / qk;
825
+
826
+ assert(n % qk == 0);
827
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
828
+ assert((nrc == 2) || (nrc == 1));
829
+ #else
830
+ assert(nrc == 1);
831
+ #endif
832
+ UNUSED(nrc);
833
+ UNUSED(bx);
834
+ UNUSED(by);
835
+ UNUSED(bs);
836
+
837
+ const block_q8_0 * LM_GGML_RESTRICT x = vx;
838
+ const block_q8_0 * LM_GGML_RESTRICT y = vy;
839
+
840
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
841
+ if (nrc == 2) {
842
+ const block_q8_0 * LM_GGML_RESTRICT vx0 = vx;
843
+ const block_q8_0 * LM_GGML_RESTRICT vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
844
+ const block_q8_0 * LM_GGML_RESTRICT vy0 = vy;
845
+ const block_q8_0 * LM_GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
846
+
847
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
848
+
849
+ for (int i = 0; i < nb; i++) {
850
+ const block_q8_0 * LM_GGML_RESTRICT b_x0 = &vx0[i];
851
+ const block_q8_0 * LM_GGML_RESTRICT b_y0 = &vy0[i];
852
+
853
+ const block_q8_0 * LM_GGML_RESTRICT b_x1 = &vx1[i];
854
+ const block_q8_0 * LM_GGML_RESTRICT b_y1 = &vy1[i];
855
+
856
+ const int8x16_t x0_l = vld1q_s8(b_x0->qs);
857
+ const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
858
+ const int8x16_t x1_l = vld1q_s8(b_x1->qs);
859
+ const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
860
+
861
+ // load y
862
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
863
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
864
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
865
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
866
+
867
+ float32_t _scale[4] = {
868
+ LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
869
+ LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d),
870
+ LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
871
+ LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d)
872
+ };
873
+ float32x4_t scale = vld1q_f32(_scale);
874
+
875
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
876
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
877
+
878
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
879
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
880
+
881
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
882
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
883
+
884
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
885
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
886
+
887
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
888
+ l1, r1)), l2, r2)), l3, r3))), scale);
889
+ }
890
+
891
+ float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
892
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
893
+
894
+ vst1_f32(s, vget_low_f32 (sumv2));
895
+ vst1_f32(s + bs, vget_high_f32(sumv2));
896
+
897
+ return;
898
+ }
899
+ #endif
900
+
901
+ int ib = 0;
902
+ float sumf = 0;
903
+
904
+ #if defined(__ARM_FEATURE_SVE)
905
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
906
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
907
+
908
+ const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
909
+
910
+ //VLA Implementation for SVE
911
+ switch (vector_length) {
912
+ case 128:
913
+ {
914
+ // predicate for activating lanes for 16 Int8 elements
915
+ const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
916
+ const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
917
+
918
+ for (; ib + 1 < nb; ib += 2) {
919
+ const block_q8_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
920
+ const block_q8_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
921
+ const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
922
+ const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
923
+
924
+ // load x
925
+ const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
926
+ const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
927
+ const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
928
+ const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
929
+
930
+ // load y
931
+ const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
932
+ const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
933
+ const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
934
+ const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
935
+
936
+ sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
937
+ svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
938
+ svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
939
+ sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
940
+ svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
941
+ svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
942
+ }
943
+
944
+ sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
945
+ } break;
946
+ case 256:
947
+ {
948
+ //printf("sve256");
949
+ for (; ib + 1 < nb; ib += 2) {
950
+ const block_q8_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
951
+ const block_q8_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
952
+ const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
953
+ const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
954
+
955
+ // load x
956
+ const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
957
+ const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
958
+
959
+ // load y
960
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
961
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
962
+
963
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
964
+ svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
965
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
966
+ svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
967
+ }
968
+
969
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
970
+ } break;
971
+ case 512:
972
+ {
973
+ // predicate for activating the high 256 bits
974
+ const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
975
+ // predicate for activating the low 256 bits
976
+ const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
977
+
978
+ // predicate for activating high lanes for 8 float32 elements
979
+ const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
980
+ // predicate for activating low lanes for 8 float32 elements
981
+ const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
982
+
983
+ svfloat32_t sumv00 = svdup_n_f32(0.0f);
984
+
985
+ for (; ib + 1 < nb; ib += 2) {
986
+ const block_q8_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
987
+ const block_q8_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
988
+ const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
989
+ const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
990
+
991
+ //load 32 int8_t into the first half of one vector and another 32 int8_t into the lower lanes of a second vector,
992
+ // then add them to form one 64-element vector
993
+ // load x
994
+ const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
995
+ svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
996
+
997
+ qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);
998
+
999
+ // load y
1000
+ const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
1001
+ svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
1002
+
1003
+ qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
1004
+
1005
+ // scale creation
1006
+ const float32_t deq1 = LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d);
1007
+ const float32_t deq2 = LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d);
1008
+
1009
+ // duplicate deq1 in first half of vector and deq2 in second half of vector
1010
+ const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
1011
+
1012
+ const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
1013
+
1014
+ sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
1015
+ }
1016
+
1017
+ sumf = svaddv_f32(svptrue_b32(), sumv00);
1018
+ break;
1019
+ }
1020
+ default:
1021
+ assert(false && "Unsupported vector length");
1022
+ break;
1023
+ }
1024
+ #elif defined(__ARM_NEON)
1025
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
1026
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
1027
+
1028
+ for (; ib + 1 < nb; ib += 2) {
1029
+ const block_q8_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
1030
+ const block_q8_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
1031
+ const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
1032
+ const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
1033
+
1034
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
1035
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
1036
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
1037
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
1038
+
1039
+ // load y
1040
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
1041
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
1042
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
1043
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
1044
+
1045
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
1046
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
1047
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
1048
+
1049
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
1050
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
1051
+ lm_ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
1052
+ }
1053
+
1054
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
1055
+ #endif
1056
+ for (; ib < nb; ++ib) {
1057
+ int sumi = 0;
1058
+
1059
+ for (int j = 0; j < qk; j++) {
1060
+ sumi += x[ib].qs[j]*y[ib].qs[j];
1061
+ }
1062
+
1063
+ sumf += sumi*(LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d));
1064
+ }
1065
+
1066
+ *s = sumf;
1067
+ }
1068
+
1069
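All of the paths in lm_ggml_vec_dot_q8_0_q8_0 compute the same thing: per 32-element block, an int32 dot product of two int8 vectors, rescaled by the product of the two per-block fp16 scales. The scalar tail is the reference; a minimal sketch over plain arrays follows, with fp32 scales and an illustrative function name rather than the library's block structs:

#include <stdint.h>

// Illustrative reference only: one scale per 32-int8 block for each operand,
// mirroring the scalar tail of lm_ggml_vec_dot_q8_0_q8_0 above (QK8_0 == 32).
static float dot_q8_blocks_ref(int n,
                               const int8_t *xq, const float *xd,    // n weights, n/32 scales
                               const int8_t *yq, const float *yd) {  // n weights, n/32 scales
    const int qk = 32;
    float sumf = 0.0f;
    for (int ib = 0; ib < n / qk; ++ib) {
        int sumi = 0;
        for (int j = 0; j < qk; ++j) {
            sumi += xq[ib*qk + j] * yq[ib*qk + j];   // int8 x int8 dot product
        }
        sumf += sumi * (xd[ib] * yd[ib]);            // rescale by both block scales
    }
    return sumf;
}

The vector paths differ only in how many blocks they fold per iteration before the same per-block rescale: svdot_s32 over SVE lanes, vdotq_s32 pairs on NEON, or the 2x2 vmmlaq_s32 tile in the MATMUL_INT8 branch.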
+ void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
1070
+ assert(nrc == 1);
1071
+ UNUSED(nrc);
1072
+ UNUSED(bx);
1073
+ UNUSED(by);
1074
+ UNUSED(bs);
1075
+
1076
+ const block_tq1_0 * LM_GGML_RESTRICT x = vx;
1077
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
1078
+
1079
+ const int nb = n / QK_K;
1080
+
1081
+ #if defined(__ARM_NEON)
1082
+ float sumf = 0.0f;
1083
+
1084
+ uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
1085
+
1086
+ const uint8x16_t shift = vld1q_u8(k_shift);
1087
+
1088
+ for (int i = 0; i < nb; ++i) {
1089
+ #if defined(__ARM_FEATURE_DOTPROD)
1090
+ int32x4_t sumi0 = vdupq_n_s32(0);
1091
+ int32x4_t sumi1 = vdupq_n_s32(0);
1092
+ #else
1093
+ int16x8_t sumi0 = vdupq_n_s16(0);
1094
+ int16x8_t sumi1 = vdupq_n_s16(0);
1095
+ #endif
1096
+
1097
+ // first 32 bytes, packing 5 elements each
1098
+ {
1099
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
1100
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
1101
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
1102
+ uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
1103
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
1104
+ uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
1105
+ uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
1106
+ uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
1107
+ uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
1108
+ uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
1109
+
1110
+ // multiply by 3 and keep the 2 bits above 8 bits
1111
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
1112
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
1113
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
1114
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
1115
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
1116
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
1117
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
1118
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
1119
+ int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
1120
+ int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
1121
+
1122
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
1123
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
1124
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
1125
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
1126
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
1127
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
1128
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
1129
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
1130
+ const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
1131
+ const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
1132
+
1133
+ #if defined(__ARM_FEATURE_DOTPROD)
1134
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
1135
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
1136
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
1137
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
1138
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
1139
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
1140
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
1141
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
1142
+ sumi0 = vdotq_s32(sumi0, sqx8, qy8);
1143
+ sumi1 = vdotq_s32(sumi1, sqx9, qy9);
1144
+ #else
1145
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
1146
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
1147
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
1148
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
1149
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
1150
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
1151
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
1152
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
1153
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
1154
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
1155
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
1156
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
1157
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
1158
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
1159
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
1160
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
1161
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
1162
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
1163
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
1164
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
1165
+ #endif
1166
+ }
1167
+
1168
+ // last 16 bytes packing 5 elements each, along with the 4 qh bytes packing 4 elements each
1169
+ {
1170
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
1171
+ uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
1172
+ uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
1173
+ uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
1174
+ uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
1175
+ uint32_t qh;
1176
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
1177
+ uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
1178
+ qx5 = vmulq_u8(qx5, shift);
1179
+
1180
+ // multiply by 3 and keep the 2 bits above 8 bits
1181
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
1182
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
1183
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
1184
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
1185
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
1186
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
1187
+
1188
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
1189
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
1190
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
1191
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
1192
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
1193
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
1194
+
1195
+ #if defined(__ARM_FEATURE_DOTPROD)
1196
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
1197
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
1198
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
1199
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
1200
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
1201
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
1202
+ #else
1203
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
1204
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
1205
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
1206
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
1207
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
1208
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
1209
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
1210
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
1211
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
1212
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
1213
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
1214
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
1215
+ #endif
1216
+ }
1217
+
1218
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
1219
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
1220
+
1221
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1222
+
1223
+ #if defined(__ARM_FEATURE_DOTPROD)
1224
+ sumi0 = vaddq_s32(sumi0, sumi1);
1225
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
1226
+
1227
+ sumf += d * (float) vaddvq_s32(sumi0);
1228
+ #else
1229
+ sumi0 = vaddq_s16(sumi0, sumi1);
1230
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
1231
+
1232
+ sumf += d * (float) vaddlvq_s16(sumi0);
1233
+ #endif
1234
+ }
1235
+
1236
+ *s = sumf;
1237
+
1238
+ #else
1239
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
1240
+
1241
+ float sumf = 0.0f;
1242
+
1243
+ for (int i = 0; i < nb; ++i) {
1244
+ int sum = 0;
1245
+
1246
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
1247
+ for (size_t l = 0; l < 5; ++l) {
1248
+ for (size_t m = 0; m < 32; ++m) {
1249
+ uint8_t q = x[i].qs[j + m] * pow3[l];
1250
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
1251
+ sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
1252
+ }
1253
+ }
1254
+ }
1255
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
1256
+ for (size_t l = 0; l < 5; ++l) {
1257
+ for (size_t m = 0; m < 16; ++m) {
1258
+ uint8_t q = x[i].qs[j + m] * pow3[l];
1259
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
1260
+ sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
1261
+ }
1262
+ }
1263
+ }
1264
+
1265
+ for (size_t l = 0; l < 4; ++l) {
1266
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
1267
+ uint8_t q = x[i].qh[j] * pow3[l];
1268
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
1269
+ sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
1270
+ }
1271
+ }
1272
+
1273
+ sumf += (float) sum * (LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d);
1274
+ }
1275
+
1276
+ *s = sumf;
1277
+ #endif
1278
+ }
1279
+
1280
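The TQ1_0 scalar fallback above decodes ternary weights packed five base-3 digits per byte: multiplying the byte by 3^l (mod 256) pushes digit l toward the top, and ((uint16_t)q * 3) >> 8 reads it back out as 0, 1 or 2 before the -1 offset. A tiny sketch of that extraction, re-expressing the arithmetic from the scalar tail (helper name illustrative):

#include <stdint.h>

// Illustrative only: extract the l-th ternary digit (as -1, 0 or +1) from a byte
// that packs five base-3 digits, using the same arithmetic as the scalar tail above.
static inline int tq1_digit(uint8_t packed, int l) {
    static const uint8_t pow3[5] = {1, 3, 9, 27, 81};
    const uint8_t  q  = (uint8_t)(packed * pow3[l]);   // shift digit l toward the top (mod 256)
    const uint16_t xi = ((uint16_t) q * 3) >> 8;       // read it out as 0, 1 or 2
    return (int) xi - 1;                               // map {0,1,2} -> {-1,0,+1}
}

The NEON path does the same thing sixteen bytes at a time: vmulq_u8 against the {1,3,9,27,81} constants, then the vhaddq_u8/vshrq_n_u8 sequence to keep the two bits above bit 8, which yields the same 0/1/2 digits.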
+ void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
1281
+ assert(nrc == 1);
1282
+ UNUSED(nrc);
1283
+ UNUSED(bx);
1284
+ UNUSED(by);
1285
+ UNUSED(bs);
1286
+
1287
+ const block_tq2_0 * LM_GGML_RESTRICT x = vx;
1288
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
1289
+
1290
+ const int nb = n / QK_K;
1291
+
1292
+ #if defined(__ARM_NEON)
1293
+ float sumf = 0.0f;
1294
+
1295
+ const uint8x16_t m3 = vdupq_n_u8(3);
1296
+
1297
+ for (int i = 0; i < nb; ++i) {
1298
+ #if defined(__ARM_FEATURE_DOTPROD)
1299
+ int32x4_t sumi0 = vdupq_n_s32(0);
1300
+ int32x4_t sumi1 = vdupq_n_s32(0);
1301
+ #else
1302
+ int16x8_t sumi0 = vdupq_n_s16(0);
1303
+ int16x8_t sumi1 = vdupq_n_s16(0);
1304
+ #endif
1305
+
1306
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
1307
+ uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
1308
+ uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
1309
+ uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
1310
+ uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
1311
+ uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
1312
+ uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
1313
+ uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
1314
+ uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
1315
+
1316
+ int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
1317
+ int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
1318
+ int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
1319
+ int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
1320
+ int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
1321
+ int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
1322
+ int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
1323
+ int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
1324
+
1325
+ const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
1326
+ const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
1327
+ const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
1328
+ const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
1329
+ const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
1330
+ const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
1331
+ const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
1332
+ const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
1333
+
1334
+ #if defined(__ARM_FEATURE_DOTPROD)
1335
+ sumi0 = vdotq_s32(sumi0, sqx0, qy0);
1336
+ sumi1 = vdotq_s32(sumi1, sqx1, qy1);
1337
+ sumi0 = vdotq_s32(sumi0, sqx2, qy2);
1338
+ sumi1 = vdotq_s32(sumi1, sqx3, qy3);
1339
+ sumi0 = vdotq_s32(sumi0, sqx4, qy4);
1340
+ sumi1 = vdotq_s32(sumi1, sqx5, qy5);
1341
+ sumi0 = vdotq_s32(sumi0, sqx6, qy6);
1342
+ sumi1 = vdotq_s32(sumi1, sqx7, qy7);
1343
+ #else
1344
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
1345
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
1346
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
1347
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
1348
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
1349
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
1350
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
1351
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
1352
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
1353
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
1354
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
1355
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
1356
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
1357
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
1358
+ sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
1359
+ sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
1360
+ #endif
1361
+ }
1362
+
1363
+ const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
1364
+ const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
1365
+
1366
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1367
+
1368
+ #if defined(__ARM_FEATURE_DOTPROD)
1369
+ sumi0 = vaddq_s32(sumi0, sumi1);
1370
+ sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
1371
+
1372
+ sumf += d * (float) vaddvq_s32(sumi0);
1373
+ #else
1374
+ sumi0 = vaddq_s16(sumi0, sumi1);
1375
+ sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
1376
+
1377
+ sumf += d * (float) vaddlvq_s16(sumi0);
1378
+ #endif
1379
+ }
1380
+
1381
+ *s = sumf;
1382
+
1383
+ #else
1384
+ float sumf = 0.0f;
1385
+
1386
+ for (int i = 0; i < nb; ++i) {
1387
+ int32_t sumi = 0;
1388
+
1389
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
1390
+ for (size_t l = 0; l < 4; ++l) {
1391
+ for (size_t k = 0; k < 32; ++k) {
1392
+ sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
1393
+ }
1394
+ }
1395
+ }
1396
+
1397
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1398
+
1399
+ sumf += (float) sumi * d;
1400
+ }
1401
+
1402
+ *s = sumf;
1403
+ #endif
1404
+ }
1405
+
1406
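TQ2_0 stores four 2-bit values per byte; the scalar fallback above shifts out each field and subtracts 1 to get a signed ternary weight. A one-line sketch of that decode, assuming the same layout as the loop above (helper name illustrative):

#include <stdint.h>

// Illustrative only: recover the l-th 2-bit field (l = 0..3) of a TQ2_0 byte and map it
// to a signed weight, exactly as the scalar fallback above does.
static inline int tq2_weight(uint8_t packed, int l) {
    return ((packed >> (l * 2)) & 3) - 1;   // {0,1,2} -> {-1,0,+1}
}

The NEON path applies the same idea in bulk with vshrq_n_u8 by 2, 4 and 6 plus vandq_u8 against the 0x03 mask, then accumulates against the q8_K activations with vdotq_s32 (or widening vmlal_s8 when dotprod is unavailable).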
+ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
1407
+ assert(nrc == 1);
1408
+ UNUSED(nrc);
1409
+ UNUSED(bx);
1410
+ UNUSED(by);
1411
+ UNUSED(bs);
1412
+
1413
+ const block_q2_K * LM_GGML_RESTRICT x = vx;
1414
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
1415
+
1416
+ const int nb = n / QK_K;
1417
+
1418
+ #ifdef __ARM_FEATURE_SVE
1419
+ const int vector_length = svcntb()*8;
1420
+ const svuint8_t m3s = svdup_n_u8(0x3);
1421
+ const svuint32_t m4s = svdup_n_u32(0xF);
1422
+ const svint32_t vzero_sv = svdup_n_s32(0);
1423
+ svfloat32_t acc_sum = svdup_n_f32(0);
1424
+ svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);
1425
+
1426
+ switch (vector_length) {
1427
+ case 128:
1428
+ for (int i = 0; i < nb; ++i) {
1429
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1430
+ svfloat32_t d_broad = svdup_n_f32((float32_t)d);
1431
+ const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
1432
+ svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
1433
+
1434
+ const uint8_t * LM_GGML_RESTRICT q2 = x[i].qs;
1435
+ const int8_t * LM_GGML_RESTRICT q8_sv = y[i].qs;
1436
+ const uint8_t * LM_GGML_RESTRICT sc = x[i].scales;
1437
+
1438
+ svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
1439
+ const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
1440
+
1441
+ mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
1442
+ const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
1443
+
1444
+ svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
1445
+ svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);
1446
+
1447
+ const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));
1448
+
1449
+ mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
1450
+ const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
1451
+
1452
+ mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
1453
+ const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
1454
+
1455
+ q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
1456
+ q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);
1457
+
1458
+ svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));
1459
+
1460
+ svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));
1461
+
1462
+ acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);
1463
+
1464
+ svint32_t sumi1 = svdup_n_s32(0);
1465
+
1466
+ {
1467
+ const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
1468
+ svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
1469
+ svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1470
+ const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));
1471
+
1472
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));
1473
+
1474
+ const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
1475
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
1476
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1477
+
1478
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));
1479
+
1480
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
1481
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1482
+
1483
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));
1484
+
1485
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
1486
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1487
+
1488
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));
1489
+
1490
+
1491
+ const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));
1492
+
1493
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
1494
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1495
+
1496
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));
1497
+
1498
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
1499
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1500
+
1501
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));
1502
+
1503
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
1504
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1505
+
1506
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));
1507
+
1508
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
1509
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1510
+
1511
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));
1512
+
1513
+ //-------------------------------
1514
+
1515
+ q2 += 32;
1516
+ const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
1517
+ const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
1518
+
1519
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
1520
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1521
+
1522
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
1523
+
1524
+ const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
1525
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
1526
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1527
+
1528
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
1529
+
1530
+
1531
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
1532
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1533
+
1534
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
1535
+
1536
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
1537
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1538
+
1539
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
1540
+
1541
+
1542
+ const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
1543
+
1544
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
1545
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1546
+
1547
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
1548
+
1549
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
1550
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1551
+
1552
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
1553
+
1554
+
1555
+
1556
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
1557
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1558
+
1559
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
1560
+
1561
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
1562
+ q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1563
+
1564
+ sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
1565
+ }
1566
+ acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
1567
+ }
1568
+ *s = svaddv_f32(svptrue_b32(), acc_sum);
1569
+ break;
1570
+
1571
+ case 256:
1572
+ case 512:
1573
+ for (int i = 0; i < nb; ++i) {
1574
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1575
+ svfloat32_t d_broad = svdup_n_f32((float32_t)d);
1576
+ const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
1577
+ svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
1578
+
1579
+ const uint8_t * LM_GGML_RESTRICT q2 = x[i].qs;
1580
+ const int8_t * LM_GGML_RESTRICT q8_sv = y[i].qs;
1581
+ const uint8_t * LM_GGML_RESTRICT sc = x[i].scales;
1582
+
1583
+ const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
1584
+ const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
1585
+ const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
1586
+ svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
1587
+
1588
+ const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
1589
+ const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
1590
+ const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
1591
+
1592
+ svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
1593
+
1594
+ svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
1595
+
1596
+ acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
1597
+
1598
+ svint32_t sumi1 = svdup_n_s32(0);
1599
+
1600
+ {
1601
+ const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
1602
+ svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
1603
+ svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1604
+
1605
+ svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
1606
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1607
+
1608
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
1609
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1610
+
1611
+ svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
1612
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);
1613
+
1614
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
1615
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1616
+
1617
+ scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
1618
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1619
+
1620
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
1621
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1622
+
1623
+ scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
1624
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
1625
+
1626
+ q2 += 32;
1627
+
1628
+ const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
1629
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
1630
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1631
+
1632
+ scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
1633
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1634
+
1635
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
1636
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1637
+
1638
+ scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
1639
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
1640
+
1641
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
1642
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1643
+
1644
+ scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
1645
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
1646
+
1647
+ q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
1648
+ q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1649
+
1650
+ scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
1651
+ sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
1652
+ }
1653
+ acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
1654
+ }
1655
+ *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);
1656
+ break;
1657
+
1658
+ default:
1659
+ assert(false && "Unsupported vector length");
1660
+ break;
1661
+ }
1662
+
1663
+ #elif __ARM_NEON
1664
+ const uint8x16_t m3 = vdupq_n_u8(0x3);
1665
+ const uint8x16_t m4 = vdupq_n_u8(0xF);
1666
+
1667
+ const int32x4_t vzero = vdupq_n_s32(0);
1668
+
1669
+ lm_ggml_int8x16x2_t q2bytes;
1670
+ uint8_t aux[16];
1671
+
1672
+ float sum = 0;
1673
+
1674
+ for (int i = 0; i < nb; ++i) {
1675
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1676
+ const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
1677
+
1678
+ const uint8_t * LM_GGML_RESTRICT q2 = x[i].qs;
1679
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
1680
+ const uint8_t * LM_GGML_RESTRICT sc = x[i].scales;
1681
+
1682
+ const uint8x16_t mins_and_scales = vld1q_u8(sc);
1683
+ const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
1684
+ vst1q_u8(aux, scales);
1685
+
1686
+ const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
1687
+ const lm_ggml_int16x8x2_t q8sums = lm_ggml_vld1q_s16_x2(y[i].bsums);
1688
+ const lm_ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
1689
+ const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
1690
+ vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
1691
+ const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
1692
+ vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
1693
+ sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
1694
+
1695
+ int isum = 0;
1696
+ int is = 0;
1697
+
1698
+ // We use this macro instead of a function call because for some reason
1699
+ // the code runs 2-3% slower, even if the function is declared inline
1700
+ #define MULTIPLY_ACCUM_WITH_SCALE(index)\
1701
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
1702
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
1703
+
1704
+ #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
1705
+ q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;\
1706
+ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
1707
+ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
1708
+ MULTIPLY_ACCUM_WITH_SCALE((index));
1709
+
1710
+ for (int j = 0; j < QK_K/128; ++j) {
1711
+ const lm_ggml_uint8x16x2_t q2bits = lm_ggml_vld1q_u8_x2(q2); q2 += 32;
1712
+
1713
+ lm_ggml_int8x16x2_t q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;
1714
+ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
1715
+ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
1716
+
1717
+ MULTIPLY_ACCUM_WITH_SCALE(0);
1718
+
1719
+ SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
1720
+ SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
1721
+ SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
1722
+
1723
+ is += 8;
1724
+ }
1725
+
1726
+ sum += d * isum;
1727
+ }
1728
+
1729
+ *s = sum;
1730
+
1731
+ #else
1732
+
1733
+ float sumf = 0;
1734
+
1735
+ for (int i = 0; i < nb; ++i) {
1736
+
1737
+ const uint8_t * q2 = x[i].qs;
1738
+ const int8_t * q8 = y[i].qs;
1739
+ const uint8_t * sc = x[i].scales;
1740
+
1741
+ int summs = 0;
1742
+ for (int j = 0; j < 16; ++j) {
1743
+ summs += y[i].bsums[j] * (sc[j] >> 4);
1744
+ }
1745
+
1746
+ const float dall = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1747
+ const float dmin = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
1748
+
1749
+ int isum = 0;
1750
+ int is = 0;
1751
+ int d;
1752
+ for (int k = 0; k < QK_K/128; ++k) {
1753
+ int shift = 0;
1754
+ for (int j = 0; j < 4; ++j) {
1755
+ d = sc[is++] & 0xF;
1756
+ int isuml = 0;
1757
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
1758
+ isum += d * isuml;
1759
+ d = sc[is++] & 0xF;
1760
+ isuml = 0;
1761
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
1762
+ isum += d * isuml;
1763
+ shift += 2;
1764
+ q8 += 32;
1765
+ }
1766
+ q2 += 32;
1767
+ }
1768
+ sumf += dall * isum - dmin * summs;
1769
+ }
1770
+ *s = sumf;
1771
+ #endif
1772
+ }
1773
+
1774
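For q2_K, every scale byte carries a 4-bit sub-block scale in its low nibble and a 4-bit min in its high nibble, and each covers 16 elements of the 256-element block; the result is then d*isum - dmin*summs, where summs is built from the q8_K bsums as in the scalar tail. A minimal sketch of the per-byte split used by all three paths above (helper name illustrative):

#include <stdint.h>

// Illustrative only: split one q2_K scale byte into its 4-bit scale and 4-bit min,
// the same layout the SVE/NEON paths unpack with shifts by 4 and the 0xF mask.
static inline void q2k_scale_min(uint8_t sc, int *scale, int *min) {
    *scale = sc & 0x0F;   // low nibble: per-16-element scale
    *min   = sc >> 4;     // high nibble: per-16-element min
}

Here d = y->d * fp16_to_fp32(x->d) and dmin = y->d * fp16_to_fp32(x->dmin), matching the scalar tail above.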
+ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
1775
+ assert(n % QK_K == 0);
1776
+ assert(nrc == 1);
1777
+ UNUSED(nrc);
1778
+ UNUSED(bx);
1779
+ UNUSED(by);
1780
+ UNUSED(bs);
1781
+
1782
+ const uint32_t kmask1 = 0x03030303;
1783
+ const uint32_t kmask2 = 0x0f0f0f0f;
1784
+
1785
+ const block_q3_K * LM_GGML_RESTRICT x = vx;
1786
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
1787
+
1788
+ const int nb = n / QK_K;
1789
+
1790
+ #if defined(__ARM_FEATURE_SVE)
1791
+
1792
+ uint32_t aux[3];
1793
+ uint32_t utmp[4];
1794
+
1795
+ const int8_t m32 = 32;
1796
+ const int vector_length = svcntb()*8;
1797
+ const svuint8_t m3b_sv = svdup_n_u8(0x3);
1798
+ const svint32_t vzero_sv = svdup_n_s32(0);
1799
+
1800
+ const svuint8_t m0_sv = svdup_n_u8(1);
1801
+ const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
1802
+ const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
1803
+ const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
1804
+
1805
+ float sum = 0;
1806
+
1807
+ for (int i = 0; i < nb; ++i) {
1808
+
1809
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1810
+
1811
+ const uint8_t * LM_GGML_RESTRICT q3_sv = x[i].qs;
1812
+ const uint8_t * LM_GGML_RESTRICT qh_sv = x[i].hmask;
1813
+ const int8_t * LM_GGML_RESTRICT q8_sv = y[i].qs;
1814
+
1815
+ // Set up scales
1816
+ memcpy(aux, x[i].scales, 12);
1817
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
1818
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
1819
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
1820
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
1821
+
1822
+ int8_t * scale = (int8_t *)utmp;
1823
+
1824
+ for (int j = 0; j < 16; ++j) scale[j] -= m32;
1825
+
1826
+ switch (vector_length) {
1827
+ case 128:
1828
+ {
1829
+ svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
1830
+ svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
1831
+ svuint8_t q3h_sv;
1832
+
1833
+ svint32_t sumi1_1 = svdup_n_s32(0);
1834
+ svint8_t q3bytes_sv;
1835
+
1836
+ for (int j = 0; j < QK_K/128; ++j) {
1837
+
1838
+ const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
1839
+ const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
1840
+ svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1841
+ svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1842
+
1843
+ q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
1844
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1845
+
1846
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
1847
+
1848
+ q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
1849
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1850
+
1851
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
1852
+
1853
+ q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1854
+ q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1855
+
1856
+ q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
1857
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1858
+
1859
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
1860
+
1861
+ q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
1862
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1863
+
1864
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
1865
+
1866
+
1867
+ scale += 4;
1868
+ q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1869
+ q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1870
+
1871
+ q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
1872
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1873
+
1874
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
1875
+
1876
+ q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
1877
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1878
+
1879
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
1880
+
1881
+
1882
+ q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1883
+ q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
1884
+
1885
+ q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
1886
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1887
+
1888
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
1889
+
1890
+ q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
1891
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1892
+
1893
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
1894
+
1895
+ if (j == 0) {
1896
+ qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
1897
+ qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
1898
+ }
1899
+
1900
+ scale += 4;
1901
+ }
1902
+
1903
+ sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
1904
+ } break;
1905
+ case 256:
1906
+ case 512:
1907
+ {
1908
+ svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
1909
+ svuint8_t q3h_sv;
1910
+
1911
+ svint32_t sumi1_1 = svdup_n_s32(0);
1912
+ svint8_t q3bytes_sv;
1913
+
1914
+ for (int j = 0; j < QK_K/128; ++j) {
1915
+
1916
+ const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
1917
+ svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1918
+ svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1919
+
1920
+ q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
1921
+ q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1922
+
1923
+
1924
+ svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
1925
+ sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
1926
+
1927
+ q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
1928
+ q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1929
+
1930
+ scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
1931
+ sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
1932
+
1933
+ scale += 4;
1934
+ q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1935
+ q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
1936
+
1937
+ q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
1938
+ q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1939
+
1940
+ scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
1941
+ sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
1942
+
1943
+ q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
1944
+ q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
1945
+
1946
+ scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
1947
+ sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
1948
+
1949
+ if (j == 0) {
1950
+ qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
1951
+ }
1952
+
1953
+ scale += 4;
1954
+ }
1955
+
1956
+ sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
1957
+ } break;
1958
+ default:
1959
+ assert(false && "Unsupported vector length");
1960
+ break;
1961
+ }
1962
+ }
1963
+ *s = sum;
1964
+
1965
+ #elif __ARM_NEON
1966
+
1967
+ uint32_t aux[3];
1968
+ uint32_t utmp[4];
1969
+
1970
+ const uint8x16_t m3b = vdupq_n_u8(0x3);
1971
+ const int32x4_t vzero = vdupq_n_s32(0);
1972
+
1973
+ const uint8x16_t m0 = vdupq_n_u8(1);
1974
+ const uint8x16_t m1 = vshlq_n_u8(m0, 1);
1975
+ const uint8x16_t m2 = vshlq_n_u8(m0, 2);
1976
+ const uint8x16_t m3 = vshlq_n_u8(m0, 3);
1977
+ const int8_t m32 = 32;
1978
+
1979
+ lm_ggml_int8x16x4_t q3bytes;
1980
+
1981
+ float sum = 0;
1982
+
1983
+ for (int i = 0; i < nb; ++i) {
1984
+
1985
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1986
+
1987
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
1988
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].hmask;
1989
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
1990
+
1991
+ lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh);
1992
+
1993
+ lm_ggml_uint8x16x4_t q3h;
1994
+
1995
+ int32_t isum = 0;
1996
+
1997
+ // Set up scales
1998
+ memcpy(aux, x[i].scales, 12);
1999
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
2000
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
2001
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
2002
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
2003
+
2004
+ int8_t * scale = (int8_t *)utmp;
2005
+ for (int j = 0; j < 16; ++j) scale[j] -= m32;
2006
+
2007
+ for (int j = 0; j < QK_K/128; ++j) {
2008
+
2009
+ const lm_ggml_uint8x16x2_t q3bits = lm_ggml_vld1q_u8_x2(q3); q3 += 32;
2010
+ const lm_ggml_int8x16x4_t q8bytes_1 = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
2011
+ const lm_ggml_int8x16x4_t q8bytes_2 = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
2012
+
2013
+ q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
2014
+ q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
2015
+ q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
2016
+ q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
2017
+
2018
+ q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
2019
+ q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
2020
+ q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
2021
+ q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
2022
+
2023
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
2024
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
2025
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
2026
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
2027
+
2028
+ scale += 4;
2029
+
2030
+ q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
2031
+ q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
2032
+ q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
2033
+ q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
2034
+
2035
+ q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
2036
+ q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
2037
+ q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
2038
+ q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
2039
+
2040
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
2041
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
2042
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
2043
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
2044
+
2045
+ scale += 4;
2046
+
2047
+ if (j == 0) {
2048
+ qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
2049
+ qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
2050
+ }
2051
+
2052
+ }
2053
+ sum += d * isum;
2054
+
2055
+ }
2056
+
2057
+ *s = sum;
2058
+
2059
+ #else
2060
+ // scalar version
2061
+ // This function is written like this so the compiler can manage to vectorize most of it
2062
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
2063
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
2064
+ // The ideal situation would be if we could just write the code once, and the compiler would
2065
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
2066
+ // write vectorized versions for AVX, ARM_NEON, etc.
2067
+
2068
+ int8_t aux8[QK_K];
2069
+ int16_t aux16[8];
2070
+ float sums [8];
2071
+ int32_t aux32[8];
2072
+ memset(sums, 0, 8*sizeof(float));
2073
+
2074
+ uint32_t auxs[4];
2075
+ const int8_t * scales = (const int8_t*)auxs;
2076
+
2077
+ float sumf = 0;
2078
+ for (int i = 0; i < nb; ++i) {
2079
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
2080
+ const uint8_t * LM_GGML_RESTRICT hm = x[i].hmask;
2081
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2082
+ memset(aux32, 0, 8*sizeof(int32_t));
2083
+ int8_t * LM_GGML_RESTRICT a = aux8;
2084
+ uint8_t m = 1;
2085
+ for (int j = 0; j < QK_K; j += 128) {
2086
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
2087
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
2088
+ a += 32; m <<= 1;
2089
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
2090
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
2091
+ a += 32; m <<= 1;
2092
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
2093
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
2094
+ a += 32; m <<= 1;
2095
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
2096
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
2097
+ a += 32; m <<= 1;
2098
+ q3 += 32;
2099
+ }
2100
+ a = aux8;
2101
+
2102
+ memcpy(auxs, x[i].scales, 12);
2103
+ uint32_t tmp = auxs[2];
2104
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
2105
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
2106
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
2107
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
2108
+ for (int j = 0; j < QK_K/16; ++j) {
2109
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2110
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
2111
+ q8 += 8; a += 8;
2112
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2113
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
2114
+ q8 += 8; a += 8;
2115
+ }
2116
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2117
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2118
+ }
2119
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2120
+ *s = sumf;
2121
+
2122
+ #endif
2123
+
2124
+ }
2125
+
2126
+ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
2127
+ assert(n % QK_K == 0);
2128
+ #ifdef __ARM_FEATURE_MATMUL_INT8
2129
+ assert((nrc == 2) || (nrc == 1));
2130
+ #else
2131
+ assert(nrc == 1);
2132
+ #endif
2133
+ UNUSED(nrc);
2134
+ UNUSED(bx);
2135
+ UNUSED(by);
2136
+ UNUSED(bs);
2137
+
2138
+ const block_q4_K * LM_GGML_RESTRICT x = vx;
2139
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
2140
+
2141
+ const int nb = n / QK_K;
2142
+
2143
+ static const uint32_t kmask1 = 0x3f3f3f3f;
2144
+ static const uint32_t kmask2 = 0x0f0f0f0f;
2145
+ static const uint32_t kmask3 = 0x03030303;
2146
+
2147
+ uint32_t utmp[4];
2148
+
2149
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
2150
+ if (nrc == 2) {
2151
+ const block_q4_K * LM_GGML_RESTRICT x0 = x;
2152
+ const block_q4_K * LM_GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
2153
+ const block_q8_K * LM_GGML_RESTRICT y0 = y;
2154
+ const block_q8_K * LM_GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
2155
+
2156
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
2157
+
2158
+ float32x4_t vfsum = vdupq_n_f32(0.0f);
2159
+
2160
+ for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
2161
+ const uint8_t * LM_GGML_RESTRICT qx0 = x0->qs;
2162
+ const uint8_t * LM_GGML_RESTRICT qx1 = x1->qs;
2163
+ const int8_t * LM_GGML_RESTRICT qy0 = y0->qs;
2164
+ const int8_t * LM_GGML_RESTRICT qy1 = y1->qs;
2165
+
2166
+ // decode scales and mins
2167
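+ // (a q4_K super-block packs eight 6-bit scales and eight 6-bit mins into its 12-byte scales field; kmask1/2/3 unpack them)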
+ int8_t x0_scales[8], x1_scales[8];
2168
+ int16x8_t x0_mins, x1_mins;
2169
+ {
2170
+ uint32_t scales_mins[3];
2171
+ memcpy(scales_mins, x0->scales, 12);
2172
+ const uint32_t mins_0_3 = scales_mins[1] & kmask1;
2173
+ const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
2174
+ const uint32x2_t mins = {mins_0_3, mins_4_7};
2175
+ x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
2176
+ uint32_t scales[2];
2177
+ scales[0] = scales_mins[0] & kmask1; // scales 0~3
2178
+ scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
2179
+ memcpy(x0_scales, scales, 8);
2180
+ }
2181
+ {
2182
+ uint32_t scales_mins[3];
2183
+ memcpy(scales_mins, x1->scales, 12);
2184
+ const uint32_t mins_0_3 = scales_mins[1] & kmask1;
2185
+ const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
2186
+ const uint32x2_t mins = {mins_0_3, mins_4_7};
2187
+ x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
2188
+ uint32_t scales[2];
2189
+ scales[0] = scales_mins[0] & kmask1; // scales 0~3
2190
+ scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
2191
+ memcpy(x1_scales, scales, 8);
2192
+ }
2193
+
2194
+ int32x4_t visum = {0};
2195
+
2196
+ // process 64 data points per iteration, 256 data points in total
2197
+ for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) {
2198
+ const int8x16x4_t vy0 = vld1q_s8_x4(qy0);
2199
+ const int8x16x4_t vy1 = vld1q_s8_x4(qy1);
2200
+
2201
+ int8x16_t vx0[4], vx1[4];
2202
+ {
2203
+ const uint8x16x2_t vv = vld1q_u8_x2(qx0);
2204
+ vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
2205
+ vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
2206
+ vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
2207
+ vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
2208
+ }
2209
+ {
2210
+ const uint8x16x2_t vv = vld1q_u8_x2(qx1);
2211
+ vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
2212
+ vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
2213
+ vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
2214
+ vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
2215
+ }
2216
+
2217
+ // process 32 data points (sharing the same block scale) per iteration
2218
+ for (int k = 0; k < 2; ++k) {
2219
+ const int blk = j * 2 + k;
2220
+ const int32x4_t block_scale = {
2221
+ x0_scales[blk],
2222
+ x0_scales[blk],
2223
+ x1_scales[blk],
2224
+ x1_scales[blk],
2225
+ };
2226
+
2227
+ int32x4_t vr = {0};
2228
+ for (int l = 0; l < 2; ++l) {
2229
+ const int idx = k * 2 + l;
2230
+ const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]);
2231
+ const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]);
2232
+ const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]);
2233
+ const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]);
2234
+ const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64));
2235
+ const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64));
2236
+ const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64));
2237
+ const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64));
2238
+ vr = vmmlaq_s32(vr, vx_l, vy_l);
2239
+ vr = vmmlaq_s32(vr, vx_h, vy_h);
2240
+ }
2241
+ // apply block scale, will NOT overflow
2242
+ // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits
2243
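+ // e.g. 256 * 15 * 128 < 2^19, and an 8-bit scale adds at most 8 more bits, so the int32 accumulator cannot overflow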
+ visum = vmlaq_s32(visum, vr, block_scale);
2244
+ }
2245
+ }
2246
+
2247
+ // adjust bias, apply superblock scale
2248
+ {
2249
+ int32_t bias[4];
2250
+ // no obvious uplift from SVE sdot-16, so just use NEON mul + add
2251
+ const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8));
2252
+ const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8));
2253
+ bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)),
2254
+ vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins))));
2255
+ bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)),
2256
+ vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins))));
2257
+ bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)),
2258
+ vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins))));
2259
+ bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)),
2260
+ vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins))));
2261
+ const float32x4_t dmins = {
2262
+ LM_GGML_CPU_FP16_TO_FP32(x0->dmin) * y0->d,
2263
+ LM_GGML_CPU_FP16_TO_FP32(x0->dmin) * y1->d,
2264
+ LM_GGML_CPU_FP16_TO_FP32(x1->dmin) * y0->d,
2265
+ LM_GGML_CPU_FP16_TO_FP32(x1->dmin) * y1->d,
2266
+ };
2267
+ vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins);
2268
+
2269
+ const float32x4_t superblock_scale = {
2270
+ LM_GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,
2271
+ LM_GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,
2272
+ LM_GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,
2273
+ LM_GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,
2274
+ };
2275
+ vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
2276
+ }
2277
+ }
2278
+
2279
+ // vfsum = ABCD -> ACBD
2280
+ // AC -> s, BD -> (s+bs)
2281
+ vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
2282
+ vst1_f32(s, vget_low_f32 (vfsum));
2283
+ vst1_f32(s + bs, vget_high_f32(vfsum));
2284
+
2285
+ return;
2286
+ }
2287
+ #endif
2288
+
2289
+ #ifdef __ARM_FEATURE_SVE
2290
+ float sumf = 0;
2291
+ for (int i = 0; i < nb; ++i) {
2292
+
2293
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
2294
+ const float dmin = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
2295
+
2296
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2297
+
2298
+ memcpy(utmp, x[i].scales, K_SCALE_SIZE);
2299
+
2300
+ uint32x2_t mins8 = { 0 };
2301
+ mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
2302
+ mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
2303
+
2304
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2305
+ utmp[0] &= kmask1;
2306
+
2307
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
2308
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
2309
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
2310
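+ // subtract the contribution of the per-sub-block minimums: dmin * sum(bsums * mins)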
+ sumf -= dmin * vaddvq_s32(prod);
2311
+
2312
+ const uint8_t * scales = (const uint8_t *)utmp;
2313
+
2314
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
2315
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2316
+
2317
+ const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
2318
+ const svuint8_t m4b = svdup_n_u8(0xf);
2319
+ const svint32_t mzero = svdup_n_s32(0);
2320
+ svint32_t sumi1 = svdup_n_s32(0);
2321
+ svint32_t sumi1_1 = svdup_n_s32(0);
2322
+ svint32_t sumi1_2 = svdup_n_s32(0);
2323
+ svint32_t sumi2 = svdup_n_s32(0);
2324
+ svint32_t sumi2_1 = svdup_n_s32(0);
2325
+ svint32_t sumi2_2 = svdup_n_s32(0);
2326
+ switch (vector_length) {
2327
+ case 128:
2328
+ {
2329
+ for (int j = 0; j < QK_K/64; ++j) {
2330
+ svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
2331
+ svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2332
+ sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
2333
+ q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
2334
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2335
+ sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
2336
+
2337
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
2338
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2339
+ sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
2340
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
2341
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
2342
+ sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
2343
+ q4 += 32;
2344
+ }
2345
+ sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
2346
+ sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
2347
+ sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
2348
+ } break;
2349
+ case 256:
2350
+ case 512:
2351
+ {
2352
+ for (int j = 0; j < QK_K/64; ++j) {
2353
+ const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
2354
+ svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
2355
+ svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
2356
+ sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
2357
+
2358
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
2359
+ q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
2360
+ sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
2361
+ }
2362
+ sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
2363
+ } break;
2364
+ default:
2365
+ assert(false && "Unsupported vector length");
2366
+ break;
2367
+ }
2368
+ }
2369
+ *s = sumf;
2370
+ #elif defined __ARM_NEON
2371
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2372
+ const int32x4_t mzero = vdupq_n_s32(0);
2373
+
2374
+ lm_ggml_int8x16x2_t q4bytes;
2375
+ lm_ggml_int8x16x2_t q8bytes;
2376
+
2377
+ float sumf = 0;
2378
+
2379
+ for (int i = 0; i < nb; ++i) {
2380
+
2381
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
2382
+ const float dmin = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
2383
+
2384
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2385
+
2386
+ memcpy(utmp, x[i].scales, 12);
2387
+
2388
+ uint32x2_t mins8 = { 0 };
2389
+ mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
2390
+ mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
2391
+
2392
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2393
+ utmp[0] &= kmask1;
2394
+
2395
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
2396
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
2397
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
2398
+ sumf -= dmin * vaddvq_s32(prod);
2399
+
2400
+ const uint8_t * scales = (const uint8_t *)utmp;
2401
+
2402
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
2403
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2404
+
2405
+ int32_t sumi1 = 0;
2406
+ int32_t sumi2 = 0;
2407
+
2408
+ for (int j = 0; j < QK_K/64; ++j) {
2409
+ const lm_ggml_uint8x16x2_t q4bits = lm_ggml_vld1q_u8_x2(q4); q4 += 32;
2410
+
2411
+ q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;
2412
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
2413
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
2414
+
2415
+ const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
2416
+ sumi1 += vaddvq_s32(p1) * scales[2*j+0];
2417
+
2418
+ q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;
2419
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
2420
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
2421
+
2422
+ const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
2423
+
2424
+ sumi2 += vaddvq_s32(p2) * scales[2*j+1];
2425
+ }
2426
+
2427
+ sumf += d * (sumi1 + sumi2);
2428
+
2429
+ }
2430
+
2431
+ *s = sumf;
2432
+
2433
+ #else
2434
+
2435
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
2436
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
2437
+
2438
+ int8_t aux8[QK_K];
2439
+ int16_t aux16[8];
2440
+ float sums [8];
2441
+ int32_t aux32[8];
2442
+ memset(sums, 0, 8*sizeof(float));
2443
+
2444
+ float sumf = 0;
2445
+ for (int i = 0; i < nb; ++i) {
2446
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
2447
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2448
+ memset(aux32, 0, 8*sizeof(int32_t));
2449
+ int8_t * LM_GGML_RESTRICT a = aux8;
2450
+ for (int j = 0; j < QK_K/64; ++j) {
2451
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
2452
+ a += 32;
2453
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
2454
+ a += 32; q4 += 32;
2455
+ }
2456
+ memcpy(utmp, x[i].scales, 12);
2457
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2458
+ const uint32_t uaux = utmp[1] & kmask1;
2459
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2460
+ utmp[2] = uaux;
2461
+ utmp[0] &= kmask1;
2462
+
2463
+ int sumi = 0;
2464
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
2465
+ a = aux8;
2466
+ int is = 0;
2467
+ for (int j = 0; j < QK_K/32; ++j) {
2468
+ int32_t scale = scales[is++];
2469
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2470
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2471
+ q8 += 8; a += 8;
2472
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2473
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2474
+ q8 += 8; a += 8;
2475
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2476
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2477
+ q8 += 8; a += 8;
2478
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2479
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2480
+ q8 += 8; a += 8;
2481
+ }
2482
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2483
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2484
+ const float dmin = LM_GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
2485
+ sumf -= dmin * sumi;
2486
+ }
2487
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2488
+ *s = sumf;
2489
+ #endif
2490
+ }
2491
+
2492
+ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
2493
+ assert(n % QK_K == 0);
2494
+ assert(nrc == 1);
2495
+ UNUSED(nrc);
2496
+ UNUSED(bx);
2497
+ UNUSED(by);
2498
+ UNUSED(bs);
2499
+
2500
+ const block_q5_K * LM_GGML_RESTRICT x = vx;
2501
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
2502
+
2503
+ const int nb = n / QK_K;
2504
+
2505
+ static const uint32_t kmask1 = 0x3f3f3f3f;
2506
+ static const uint32_t kmask2 = 0x0f0f0f0f;
2507
+ static const uint32_t kmask3 = 0x03030303;
2508
+
2509
+ uint32_t utmp[4];
2510
+
2511
+
2512
+ #ifdef __ARM_NEON
2513
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2514
+ const uint8x16_t mone = vdupq_n_u8(1);
2515
+ const uint8x16_t mtwo = vdupq_n_u8(2);
2516
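+ // mone/mtwo select the qh bit that supplies the 5th bit of the low- and high-nibble values of each 64-value chunk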
+ const int32x4_t mzero = vdupq_n_s32(0);
2517
+
2518
+ lm_ggml_int8x16x4_t q5bytes;
2519
+
2520
+ float sumf = 0;
2521
+
2522
+ for (int i = 0; i < nb; ++i) {
2523
+
2524
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
2525
+ const float dmin = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
2526
+
2527
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2528
+
2529
+ memcpy(utmp, x[i].scales, 12);
2530
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2531
+ const uint32_t uaux = utmp[1] & kmask1;
2532
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2533
+ utmp[2] = uaux;
2534
+ utmp[0] &= kmask1;
2535
+
2536
+ const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
2537
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
2538
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
2539
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
2540
+ int32_t sumi_mins = vaddvq_s32(prod);
2541
+
2542
+ const uint8_t * scales = (const uint8_t *)utmp;
2543
+
2544
+ const uint8_t * LM_GGML_RESTRICT q5 = x[i].qs;
2545
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
2546
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2547
+
2548
+ lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh);
2549
+
2550
+ lm_ggml_uint8x16x4_t q5h;
2551
+
2552
+ int32_t sumi = 0;
2553
+
2554
+ for (int j = 0; j < QK_K/64; ++j) {
2555
+
2556
+ const lm_ggml_uint8x16x2_t q5bits = lm_ggml_vld1q_u8_x2(q5); q5 += 32;
2557
+ const lm_ggml_int8x16x4_t q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
2558
+
2559
+ q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
2560
+ q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
2561
+ q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
2562
+ q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
2563
+ qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
2564
+ qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
2565
+
2566
+ q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
2567
+ q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
2568
+ q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
2569
+ q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
2570
+
2571
+ sumi += vaddvq_s32(lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
2572
+ sumi += vaddvq_s32(lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
2573
+ }
2574
+
2575
+ sumf += d * sumi - dmin * sumi_mins;
2576
+ }
2577
+
2578
+ *s = sumf;
2579
+
2580
+ #else
2581
+
2582
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
2583
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
2584
+
2585
+ int8_t aux8[QK_K];
2586
+ int16_t aux16[8];
2587
+ float sums [8];
2588
+ int32_t aux32[8];
2589
+ memset(sums, 0, 8*sizeof(float));
2590
+
2591
+ float sumf = 0;
2592
+ for (int i = 0; i < nb; ++i) {
2593
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
2594
+ const uint8_t * LM_GGML_RESTRICT hm = x[i].qh;
2595
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2596
+ memset(aux32, 0, 8*sizeof(int32_t));
2597
+ int8_t * LM_GGML_RESTRICT a = aux8;
2598
+ uint8_t m = 1;
2599
+ for (int j = 0; j < QK_K/64; ++j) {
2600
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
2601
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
2602
+ a += 32; m <<= 1;
2603
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
2604
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
2605
+ a += 32; m <<= 1;
2606
+ q4 += 32;
2607
+ }
2608
+ memcpy(utmp, x[i].scales, 12);
2609
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2610
+ const uint32_t uaux = utmp[1] & kmask1;
2611
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2612
+ utmp[2] = uaux;
2613
+ utmp[0] &= kmask1;
2614
+
2615
+ int sumi = 0;
2616
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
2617
+ a = aux8;
2618
+ int is = 0;
2619
+ for (int j = 0; j < QK_K/32; ++j) {
2620
+ int32_t scale = scales[is++];
2621
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2622
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2623
+ q8 += 8; a += 8;
2624
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2625
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2626
+ q8 += 8; a += 8;
2627
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2628
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2629
+ q8 += 8; a += 8;
2630
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2631
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2632
+ q8 += 8; a += 8;
2633
+ }
2634
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2635
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2636
+ const float dmin = LM_GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
2637
+ sumf -= dmin * sumi;
2638
+ }
2639
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2640
+ *s = sumf;
2641
+ #endif
2642
+ }
2643
+
2644
+ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
2645
+ assert(n % QK_K == 0);
2646
+ #ifdef __ARM_FEATURE_MATMUL_INT8
2647
+ assert((nrc == 2) || (nrc == 1));
2648
+ #else
2649
+ assert(nrc == 1);
2650
+ #endif
2651
+ UNUSED(nrc);
2652
+ UNUSED(bx);
2653
+ UNUSED(by);
2654
+ UNUSED(bs);
2655
+
2656
+ const block_q6_K * LM_GGML_RESTRICT x = vx;
2657
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
2658
+
2659
+ const int nb = n / QK_K;
2660
+
2661
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
2662
+ if (nrc == 2) {
2663
+ const block_q6_K * LM_GGML_RESTRICT x0 = x;
2664
+ const block_q6_K * LM_GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
2665
+ const block_q8_K * LM_GGML_RESTRICT y0 = y;
2666
+ const block_q8_K * LM_GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
2667
+
2668
+ float32x4_t vfsum = vdupq_n_f32(0.0f);
2669
+
2670
+ for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
2671
+ const uint8_t * LM_GGML_RESTRICT ql0 = x0->ql;
2672
+ const uint8_t * LM_GGML_RESTRICT ql1 = x1->ql;
2673
+ const uint8_t * LM_GGML_RESTRICT qh0 = x0->qh;
2674
+ const uint8_t * LM_GGML_RESTRICT qh1 = x1->qh;
2675
+ const int8_t * LM_GGML_RESTRICT qy0 = y0->qs;
2676
+ const int8_t * LM_GGML_RESTRICT qy1 = y1->qs;
2677
+
2678
+ const uint8x16_t mone = vdupq_n_u8(0x30);
2679
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
2680
+
2681
+ int32x4_t visum = vdupq_n_s32(0);
2682
+
2683
+ // process 8 blocks per iteration, 16 blocks in total
2684
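+ // (a q6_K super-block holds 16 sub-blocks of 16 values, each with its own int8 scale)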
+ for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
2685
+ int8x16_t vx0[8], vx1[8];
2686
+
2687
+ // de-quantize vx0[8]
2688
+ {
2689
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
2690
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
2691
+
2692
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
2693
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
2694
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
2695
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
2696
+
2697
+ vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
2698
+ vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
2699
+ vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
2700
+ vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
2701
+
2702
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
2703
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
2704
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
2705
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
2706
+
2707
+ vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
2708
+ vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
2709
+ vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
2710
+ vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
2711
+ }
2712
+
2713
+ // de-quantize vx1[8]
2714
+ {
2715
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
2716
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
2717
+
2718
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
2719
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
2720
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
2721
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
2722
+
2723
+ vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
2724
+ vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
2725
+ vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
2726
+ vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
2727
+
2728
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
2729
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
2730
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
2731
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
2732
+
2733
+ vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
2734
+ vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
2735
+ vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
2736
+ vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
2737
+ }
2738
+
2739
+ // process 16 elements (one block, all sharing the same scale) per iteration
2740
+ // - vx = concat(ql, qh) - 32
2741
+ // - r1,r2,r3,r4 = smmla(vx, vy)
2742
+ for (int k = 0; k < 8; ++k) {
2743
+ const int blk = j * 8 + k;
2744
+
2745
+ const int8x16_t vy0 = vld1q_s8(qy0);
2746
+ const int8x16_t vy1 = vld1q_s8(qy1);
2747
+ qy0 += 16;
2748
+ qy1 += 16;
2749
+
2750
+ const int32x4_t block_scale = {
2751
+ x0->scales[blk],
2752
+ x0->scales[blk],
2753
+ x1->scales[blk],
2754
+ x1->scales[blk],
2755
+ };
2756
+
2757
+ // calculate four results at once with outer product
2758
+ const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
2759
+ const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
2760
+ const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
2761
+ const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
2762
+ int32x4_t vr = vdupq_n_s32(0);
2763
+ vr = vmmlaq_s32(vr, vx_l, vy_l);
2764
+ vr = vmmlaq_s32(vr, vx_h, vy_h);
2765
+
2766
+ // apply block scale, will NOT overflow
2767
+ // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
2768
+ visum = vmlaq_s32(visum, vr, block_scale);
2769
+ }
2770
+ }
2771
+
2772
+ // adjust bias, apply superblock scale
2773
+ {
2774
+ int32_t bias[4];
2775
+ #ifdef __ARM_FEATURE_SVE
2776
+ const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
2777
+ const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
2778
+ const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
2779
+ const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
2780
+ const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
2781
+ const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
2782
+ const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
2783
+ const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
2784
+ const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
2785
+ const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
2786
+ const svint64_t zero = svdup_n_s64(0);
2787
+ bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
2788
+ svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
2789
+ bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
2790
+ svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
2791
+ bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
2792
+ svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
2793
+ bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
2794
+ svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
2795
+ #else
2796
+ // NEON doesn't support an int16 dot product, so fall back to separate mul and add
2797
+ const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
2798
+ const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
2799
+
2800
+ int8x16_t scales_s8 = vld1q_s8(x0->scales);
2801
+ const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
2802
+ scales_s8 = vld1q_s8(x1->scales);
2803
+ const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
2804
+
2805
+ int32x4_t prod;
2806
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
2807
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
2808
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
2809
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
2810
+ bias[0] = vaddvq_s32(prod);
2811
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
2812
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
2813
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
2814
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
2815
+ bias[1] = vaddvq_s32(prod);
2816
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
2817
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
2818
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
2819
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
2820
+ bias[2] = vaddvq_s32(prod);
2821
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
2822
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
2823
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
2824
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
2825
+ bias[3] = vaddvq_s32(prod);
2826
+
2827
+ #endif
2828
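+ // the 6-bit values above were used without subtracting their 32 offset, so remove 32 * sum(bsums * scales) here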
+ const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
2829
+
2830
+ const float32x4_t superblock_scale = {
2831
+ LM_GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,
2832
+ LM_GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,
2833
+ LM_GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,
2834
+ LM_GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,
2835
+ };
2836
+
2837
+ visum = vsubq_s32(visum, vibias);
2838
+ vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
2839
+ }
2840
+ }
2841
+
2842
+ // vfsum = ABCD -> ACBD
2843
+ // AC -> s, BD -> (s+bs)
2844
+ vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
2845
+ vst1_f32(s, vget_low_f32 (vfsum));
2846
+ vst1_f32(s + bs, vget_high_f32(vfsum));
2847
+
2848
+ return;
2849
+ }
2850
+ #endif
2851
+
2852
+ #ifdef __ARM_FEATURE_SVE
2853
+ const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
2854
+ float sum = 0;
2855
+ svuint8_t m4b = svdup_n_u8(0xf);
2856
+ svint32_t vzero = svdup_n_s32(0);
2857
+ svuint8_t mone = svdup_n_u8(0x30);
2858
+ svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
2859
+ svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
2860
+
2861
+ for (int i = 0; i < nb; ++i) {
2862
+ const float d_all = LM_GGML_CPU_FP16_TO_FP32(x[i].d);
2863
+
2864
+ const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
2865
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
2866
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2867
+
2868
+ const int8_t * LM_GGML_RESTRICT scale = x[i].scales;
2869
+
2870
+ const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
2871
+ const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
2872
+ const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
2873
+ const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
2874
+ const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
2875
+ const svint64_t prod = svdup_n_s64(0);
2876
+ int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
2877
+ svdot_s64(prod, q8sums_2, q6scales_2)));
2878
+ int32_t isum = 0;
2879
+
2880
+ switch (vector_length) {
2881
+ case 128:
2882
+ {
2883
+ const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
2884
+ const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
2885
+ svint32_t isum_tmp = svdup_n_s32(0);
2886
+ for (int j = 0; j < QK_K/128; ++j) {
2887
+ svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
2888
+ svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
2889
+ qh += 32;
2890
+ svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
2891
+ svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
2892
+ svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
2893
+ svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
2894
+ q6 += 64;
2895
+ svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
2896
+ svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
2897
+ svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
2898
+ svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
2899
+ q8 += 64;
2900
+
2901
+ q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
2902
+ q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
2903
+ q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
2904
+ q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
2905
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
2906
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
2907
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
2908
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
2909
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
2910
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
2911
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
2912
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
2913
+
2914
+ scale += 4;
2915
+ q8bytes_1 = svld1_s8(pg8_16, q8);
2916
+ q8bytes_2 = svld1_s8(pg8_16, q8+16);
2917
+ q8bytes_3 = svld1_s8(pg8_16, q8+32);
2918
+ q8bytes_4 = svld1_s8(pg8_16, q8+48);
2919
+ q8 += 64;
2920
+
2921
+ q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
2922
+ q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
2923
+ q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
2924
+ q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
2925
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
2926
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
2927
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
2928
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
2929
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
2930
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
2931
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
2932
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
2933
+ scale += 4;
2934
+ }
2935
+ isum += svaddv_s32(pg32_4, isum_tmp);
2936
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
2937
+ }
2938
+ break;
2939
+ case 256:
2940
+ case 512:
2941
+ {
2942
+ const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
2943
+ const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
2944
+ const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
2945
+ svint32_t isum_tmp = svdup_n_s32(0);
2946
+ for (int j = 0; j < QK_K/128; j++) {
2947
+ svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
2948
+ qh += 32;
2949
+ svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
2950
+ svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
2951
+ q6 += 64;
2952
+ svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
2953
+ svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
2954
+ svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
2955
+ svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
2956
+ q8 += 128;
2957
+ q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
2958
+ q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
2959
+ q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
2960
+ q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
2961
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
2962
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
2963
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
2964
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
2965
+
2966
+ svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
2967
+ scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
2968
+ scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
2969
+ svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
2970
+ scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
2971
+ scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
2972
+ svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
2973
+ scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
2974
+ scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
2975
+ svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
2976
+ scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
2977
+ scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
2978
+ svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
2979
+ svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
2980
+ svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
2981
+ svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
2982
+
2983
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
2984
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
2985
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
2986
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
2987
+ scale += 8;
2988
+ }
2989
+ isum += svaddv_s32(pg32_8, isum_tmp);
2990
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
2991
+ }
2992
+ break;
2993
+ default:
2994
+ assert(false && "Unsupported vector length");
2995
+ break;
2996
+ }
2997
+ }
2998
+
2999
+ *s = sum;
3000
+
3001
+ #elif __ARM_NEON
3002
+ float sum = 0;
3003
+
3004
+ const uint8x16_t m4b = vdupq_n_u8(0xF);
3005
+ const int32x4_t vzero = vdupq_n_s32(0);
3006
+ //const int8x16_t m32s = vdupq_n_s8(32);
3007
+
3008
+ const uint8x16_t mone = vdupq_n_u8(3);
3009
+
3010
+ lm_ggml_int8x16x4_t q6bytes;
3011
+ lm_ggml_uint8x16x4_t q6h;
3012
+
3013
+ for (int i = 0; i < nb; ++i) {
3014
+
3015
+ const float d_all = LM_GGML_CPU_FP16_TO_FP32(x[i].d);
3016
+
3017
+ const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
3018
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3019
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3020
+
3021
+ const int8_t * LM_GGML_RESTRICT scale = x[i].scales;
3022
+
3023
+ const lm_ggml_int16x8x2_t q8sums = lm_ggml_vld1q_s16_x2(y[i].bsums);
3024
+ const int8x16_t scales = vld1q_s8(scale);
3025
+ const lm_ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
3026
+
3027
+ const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
3028
+ vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
3029
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
3030
+ vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
3031
+ int32_t isum_mins = vaddvq_s32(prod);
3032
+
3033
+ int32_t isum = 0;
3034
+
3035
+ for (int j = 0; j < QK_K/128; ++j) {
3036
+
3037
+ lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh); qh += 32;
3038
+ lm_ggml_uint8x16x4_t q6bits = lm_ggml_vld1q_u8_x4(q6); q6 += 64;
3039
+ lm_ggml_int8x16x4_t q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
3040
+
3041
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
3042
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
3043
+ uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
3044
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3045
+ shifted = vshrq_n_u8(qhbits.val[1], 2);
3046
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3047
+
3048
+ //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
3049
+ //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
3050
+ //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
3051
+ //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
3052
+ q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
3053
+ q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
3054
+ q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
3055
+ q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
3056
+
3057
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
3058
+ vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
3059
+ vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
3060
+ vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
3061
+
3062
+ scale += 4;
3063
+
3064
+ q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
3065
+
3066
+ shifted = vshrq_n_u8(qhbits.val[0], 4);
3067
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3068
+ shifted = vshrq_n_u8(qhbits.val[1], 4);
3069
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3070
+ shifted = vshrq_n_u8(qhbits.val[0], 6);
3071
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3072
+ shifted = vshrq_n_u8(qhbits.val[1], 6);
3073
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3074
+
3075
+ //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
3076
+ //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
3077
+ //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
3078
+ //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
3079
+ q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
3080
+ q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
3081
+ q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
3082
+ q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
3083
+
3084
+ isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
3085
+ vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
3086
+ vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
3087
+ vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
3088
+ scale += 4;
3089
+ }
3090
+ //sum += isum * d_all * y[i].d;
3091
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
3092
+
3093
+ }
3094
+ *s = sum;
3095
+ #else
3096
+
3097
+ int8_t aux8[QK_K];
3098
+ int16_t aux16[8];
3099
+ float sums [8];
3100
+ int32_t aux32[8];
3101
+ memset(sums, 0, 8*sizeof(float));
3102
+
3103
+ float sumf = 0;
3104
+ for (int i = 0; i < nb; ++i) {
3105
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].ql;
3106
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3107
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3108
+ memset(aux32, 0, 8*sizeof(int32_t));
3109
+ int8_t * LM_GGML_RESTRICT a = aux8;
3110
+ for (int j = 0; j < QK_K; j += 128) {
3111
+ for (int l = 0; l < 32; ++l) {
3112
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
3113
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
3114
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
3115
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
3116
+ }
3117
+ a += 128;
3118
+ q4 += 64;
3119
+ qh += 32;
3120
+ }
3121
+ a = aux8;
3122
+ int is = 0;
3123
+ for (int j = 0; j < QK_K/16; ++j) {
3124
+ int scale = x[i].scales[is++];
3125
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
3126
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
3127
+ q8 += 8; a += 8;
3128
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
3129
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
3130
+ q8 += 8; a += 8;
3131
+ }
3132
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3133
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
3134
+ }
3135
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
3136
+ *s = sumf;
3137
+ #endif
3138
+ }
3139
+
3140
+ #if defined (__ARM_NEON)
3141
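+ // 128 sign patterns of 8 bytes each (+1/-1); the eighth sign is chosen so every pattern has an even number of -1s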
+ static const int8_t keven_signs_q2xs[1024] = {
3142
+ 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
3143
+ 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
3144
+ 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
3145
+ 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
3146
+ 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
3147
+ 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
3148
+ 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
3149
+ 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
3150
+ 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
3151
+ 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
3152
+ 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
3153
+ 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
3154
+ 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
3155
+ 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
3156
+ 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
3157
+ 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
3158
+ 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
3159
+ 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
3160
+ 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
3161
+ 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
3162
+ 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
3163
+ 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
3164
+ 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
3165
+ 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
3166
+ 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
3167
+ 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
3168
+ 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
3169
+ 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
3170
+ 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
3171
+ 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
3172
+ 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
3173
+ 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
3174
+ };
3175
+ #endif
3176
+
3177
+ void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3178
+ assert(n % QK_K == 0);
3179
+ assert(nrc == 1);
3180
+ UNUSED(nrc);
3181
+ UNUSED(bx);
3182
+ UNUSED(by);
3183
+ UNUSED(bs);
3184
+
3185
+ const block_iq2_xxs * LM_GGML_RESTRICT x = vx;
3186
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3187
+
3188
+ const int nb = n / QK_K;
3189
+
3190
+ #if defined(__ARM_NEON)
3191
+
3192
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3193
+
3194
+ uint32_t aux32[4];
3195
+ const uint8_t * aux8 = (const uint8_t *)aux32;
3196
+
3197
+ lm_ggml_int8x16x4_t q2u;
3198
+ lm_ggml_int8x16x4_t q2s;
3199
+ lm_ggml_int8x16x4_t q8b;
3200
+
3201
+ float sumf = 0;
3202
+ for (int i = 0; i < nb; ++i) {
3203
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3204
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
3205
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3206
+ float sumf1 = 0, sumf2 = 0;
3207
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3208
+ q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
3209
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
3210
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
3211
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
3212
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
3213
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
3214
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
3215
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
3216
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
3217
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
3218
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
3219
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
3220
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
3221
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
3222
+ const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
3223
+ const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
3224
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
3225
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
3226
+ }
3227
+ sumf += d*(sumf1 + sumf2);
3228
+ }
3229
+ *s = 0.25f * sumf;
3230
+
3231
+ #else
3232
+
3233
+ uint32_t aux32[2];
3234
+ const uint8_t * aux8 = (const uint8_t *)aux32;
3235
+
3236
+ float sumf = 0.f;
3237
+ for (int i = 0; i < nb; ++i) {
3238
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3239
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
3240
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3241
+ int32_t bsum = 0;
3242
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3243
+ memcpy(aux32, q2, 2*sizeof(uint32_t));
3244
+ q2 += 4;
3245
+ const uint32_t ls = 2*(aux32[1] >> 28) + 1;
3246
+ int32_t sumi = 0;
3247
+ for (int l = 0; l < 4; ++l) {
3248
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
3249
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
3250
+ for (int j = 0; j < 8; ++j) {
3251
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
3252
+ }
3253
+ q8 += 8;
3254
+ }
3255
+ bsum += sumi * ls;
3256
+ }
3257
+ sumf += d * bsum;
3258
+ }
3259
+ *s = 0.125f * sumf;
3260
+ #endif
3261
+ }
3262
+
3263
+ void lm_ggml_vec_dot_iq2_xs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3264
+ assert(n % QK_K == 0);
3265
+ assert(nrc == 1);
3266
+ UNUSED(nrc);
3267
+ UNUSED(bx);
3268
+ UNUSED(by);
3269
+ UNUSED(bs);
3270
+
3271
+ const block_iq2_xs * LM_GGML_RESTRICT x = vx;
3272
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3273
+
3274
+ const int nb = n / QK_K;
3275
+
3276
+ #if defined(__ARM_NEON)
3277
+
3278
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3279
+
3280
+ lm_ggml_int8x16x4_t q2u;
3281
+ lm_ggml_int8x16x4_t q2s;
3282
+ lm_ggml_int8x16x4_t q8b;
3283
+
3284
+ int32x4x4_t scales32;
3285
+
3286
+ float sumf = 0;
3287
+ for (int i = 0; i < nb; ++i) {
3288
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3289
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
3290
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3291
+ const uint8x8_t scales8 = vld1_u8(x[i].scales);
3292
+ const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
3293
+ const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
3294
+ uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
3295
+ scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
3296
+ const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
3297
+ const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
3298
+ scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
3299
+ scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
3300
+ scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
3301
+ scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
3302
+ int32x4_t sumi = vdupq_n_s32(0);
3303
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
3304
+ q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
3305
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
3306
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
3307
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
3308
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
3309
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
3310
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
3311
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
3312
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
3313
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
3314
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
3315
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
3316
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
3317
+ const int32x4_t p1 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
3318
+ const int32x4_t p2 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
3319
+ const int32x4_t p3 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
3320
+ const int32x4_t p4 = lm_ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
3321
+ const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
3322
+ sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
3323
+ q2 += 8;
3324
+ }
3325
+ sumf += d*vaddvq_s32(sumi);
3326
+ }
3327
+ *s = 0.125f * sumf;
3328
+
3329
+ #else
3330
+
3331
+ float sumf = 0.f;
3332
+ for (int i = 0; i < nb; ++i) {
3333
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3334
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
3335
+ const uint8_t * LM_GGML_RESTRICT sc = x[i].scales;
3336
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3337
+ int32_t bsum = 0;
3338
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3339
+ const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
3340
+ const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
3341
+ int32_t sumi = 0;
3342
+ for (int l = 0; l < 2; ++l) {
3343
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
3344
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
3345
+ for (int j = 0; j < 8; ++j) {
3346
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
3347
+ }
3348
+ q8 += 8;
3349
+ }
3350
+ bsum += sumi * ls1;
3351
+ sumi = 0;
3352
+ for (int l = 2; l < 4; ++l) {
3353
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
3354
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
3355
+ for (int j = 0; j < 8; ++j) {
3356
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
3357
+ }
3358
+ q8 += 8;
3359
+ }
3360
+ bsum += sumi * ls2;
3361
+ q2 += 4;
3362
+ }
3363
+ sumf += d * bsum;
3364
+ }
3365
+ *s = 0.125f * sumf;
3366
+ #endif
3367
+ }
3368
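Editor's note, not part of the package diff: a minimal sketch of how the iq2_xs kernel above splits each 16-bit word of `x[i].qs`, taken directly from the `& 511` / `>> 9` expressions in both code paths. Names are illustrative.

```c
/* Editorial sketch, not package code: per-word split used by the iq2_xs
 * kernel above. */
#include <stdint.h>

static inline void iq2_xs_split(uint16_t w, uint16_t *grid_idx, uint8_t *sign_idx) {
    *grid_idx = w & 511;           /* matches (q2[l] & 511): index into iq2xs_grid   */
    *sign_idx = (uint8_t)(w >> 9); /* matches q2[l] >> 9: index into ksigns_iq2xs    */
}
/* Per-32-weight scales come from the nibbles of x[i].scales[ib32], used as
 * 2*nibble + 1; the final sum is multiplied by 0.125f, exactly as above. */
```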
+
3369
+ void lm_ggml_vec_dot_iq2_s_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3370
+ assert(n % QK_K == 0);
3371
+ assert(nrc == 1);
3372
+ UNUSED(nrc);
3373
+ UNUSED(bx);
3374
+ UNUSED(by);
3375
+ UNUSED(bs);
3376
+
3377
+ const block_iq2_s * LM_GGML_RESTRICT x = vx;
3378
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3379
+
3380
+ const int nb = n / QK_K;
3381
+
3382
+ #if defined(__ARM_NEON)
3383
+
3384
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3385
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3386
+ };
3387
+
3388
+ static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
3389
+
3390
+ const lm_ggml_uint8x16x2_t mask1 = lm_ggml_vld1q_u8_x2(k_mask1);
3391
+ const uint8x16_t mask2 = vld1q_u8(k_mask2);
3392
+ const uint8x16_t m1 = vdupq_n_u8(1);
3393
+ const int32x4_t vzero = vdupq_n_s32(0);
3394
+
3395
+ uint8x16x2_t vs;
3396
+ lm_ggml_int8x16x4_t q2s;
3397
+ lm_ggml_int8x16x4_t q8b;
3398
+
3399
+ float sumf = 0;
3400
+ for (int i = 0; i < nb; ++i) {
3401
+
3402
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3403
+
3404
+ const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
3405
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3406
+ const uint16_t * LM_GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
3407
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3408
+
3409
+ int sumi1 = 0, sumi2 = 0;
3410
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3411
+ q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
3412
+ q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))),
3413
+ vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300)))));
3414
+ q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))),
3415
+ vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300)))));
3416
+ q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))),
3417
+ vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300)))));
3418
+ q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))),
3419
+ vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
3420
+ qs += 8;
3421
+
3422
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
3423
+ vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3424
+ vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3425
+ vs.val[0] = vceqq_u8(vs.val[0], mask2);
3426
+ vs.val[1] = vceqq_u8(vs.val[1], mask2);
3427
+
3428
+ q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
3429
+ q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
3430
+
3431
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
3432
+ vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3433
+ vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3434
+ vs.val[0] = vceqq_u8(vs.val[0], mask2);
3435
+ vs.val[1] = vceqq_u8(vs.val[1], mask2);
3436
+
3437
+ signs += 4;
3438
+
3439
+ q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]);
3440
+ q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]);
3441
+
3442
+ const int32x4_t p1 = lm_ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]);
3443
+ const int32x4_t p2 = lm_ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]);
3444
+ const int32x4_t p3 = lm_ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]);
3445
+ const int32x4_t p4 = lm_ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]);
3446
+
3447
+ sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf));
3448
+ sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4));
3449
+ sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf));
3450
+ sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4));
3451
+ }
3452
+ sumf += d*(sumi1 + sumi2);
3453
+ }
3454
+
3455
+ *s = 0.125f * sumf;
3456
+
3457
+ #else
3458
+
3459
+ float sumf = 0;
3460
+ for (int i = 0; i < nb; i++) {
3461
+
3462
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3463
+ const int8_t * q8 = y[i].qs;
3464
+ const uint8_t * qs = x[i].qs;
3465
+ const uint8_t * qh = x[i].qh;
3466
+ const uint8_t * signs = qs + QK_K/8;
3467
+
3468
+ int bsum = 0;
3469
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3470
+ int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
3471
+ int ls2 = 1 + 2*(x[i].scales[ib32] >> 4);
3472
+ int sumi1 = 0, sumi2 = 0;
3473
+ for (int l = 0; l < 2; ++l) {
3474
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
3475
+ for (int j = 0; j < 8; ++j) {
3476
+ sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
3477
+ }
3478
+ q8 += 8;
3479
+ }
3480
+ for (int l = 2; l < 4; ++l) {
3481
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
3482
+ for (int j = 0; j < 8; ++j) {
3483
+ sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
3484
+ }
3485
+ q8 += 8;
3486
+ }
3487
+ bsum += ls1 * sumi1 + ls2 * sumi2;
3488
+ qs += 4;
3489
+ signs += 4;
3490
+ }
3491
+
3492
+ sumf += d * bsum;
3493
+ }
3494
+
3495
+ *s = 0.125f * sumf;
3496
+
3497
+ #endif
3498
+
3499
+ }
3500
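Editor's note, not part of the package diff: a minimal sketch of the 10-bit grid-index assembly used by the iq2_s scalar path above, with `l` being the position (0..3) of the 8-weight group inside one 32-weight sub-block. Names are illustrative.

```c
/* Editorial sketch, not package code: index assembly from the iq2_s scalar
 * path above. */
#include <stdint.h>

static inline uint16_t iq2_s_grid_index(uint8_t qs_l, uint8_t qh_ib32, int l) {
    /* 8 low bits from qs, 2 high bits taken from qh, as in
     * qs[l] | (qh[ib32] << (8-2*l) & 0x300) in the kernel above */
    return (uint16_t)(qs_l | (((uint16_t)qh_ib32 << (8 - 2*l)) & 0x300));
}
/* Signs are stored explicitly, one byte per 8 weights starting at
 * x[i].qs + QK_K/8; bit j of that byte (kmask_iq2xs[j]) negates weight j. */
```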
+
3501
+ void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3502
+ assert(n % QK_K == 0);
3503
+ assert(nrc == 1);
3504
+ UNUSED(nrc);
3505
+ UNUSED(bx);
3506
+ UNUSED(by);
3507
+ UNUSED(bs);
3508
+
3509
+ const block_iq3_xxs * LM_GGML_RESTRICT x = vx;
3510
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3511
+
3512
+ const int nb = n / QK_K;
3513
+
3514
+ #if defined(__ARM_NEON)
3515
+
3516
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3517
+
3518
+ uint32_t aux32[2];
3519
+
3520
+ lm_ggml_int8x16x4_t q3s;
3521
+ lm_ggml_int8x16x4_t q8b;
3522
+
3523
+ float sumf = 0;
3524
+ for (int i = 0; i < nb; ++i) {
3525
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3526
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
3527
+ const uint8_t * LM_GGML_RESTRICT gas = x[i].qs + QK_K/4;
3528
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3529
+ float sumf1 = 0, sumf2 = 0;
3530
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3531
+ q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
3532
+ memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
3533
+ const uint32x4_t aux32x4_0 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
3534
+ const uint32x4_t aux32x4_1 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
3535
+ const uint32x4_t aux32x4_2 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
3536
+ const uint32x4_t aux32x4_3 = lm_ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
3537
+ q3 += 16;
3538
+ q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
3539
+ q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
3540
+ q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
3541
+ q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
3542
+ q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
3543
+ q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
3544
+ q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
3545
+ q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
3546
+ const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
3547
+ const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
3548
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
3549
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
3550
+ }
3551
+ sumf += d*(sumf1 + sumf2);
3552
+ }
3553
+ *s = 0.5f * sumf;
3554
+
3555
+ #else
3556
+
3557
+ uint32_t aux32;
3558
+
3559
+ float sumf = 0.f;
3560
+ for (int i = 0; i < nb; ++i) {
3561
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3562
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
3563
+ const uint8_t * LM_GGML_RESTRICT gas = x[i].qs + QK_K/4;
3564
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3565
+ int32_t bsum = 0;
3566
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3567
+ memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
3568
+ const uint32_t ls = 2*(aux32 >> 28) + 1;
3569
+ int32_t sumi = 0;
3570
+ for (int l = 0; l < 4; ++l) {
3571
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
3572
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
3573
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
3574
+ for (int j = 0; j < 4; ++j) {
3575
+ sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
3576
+ sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
3577
+ }
3578
+ q8 += 8;
3579
+ }
3580
+ q3 += 8;
3581
+ bsum += sumi * ls;
3582
+ }
3583
+ sumf += d * bsum;
3584
+ }
3585
+ *s = 0.25f * sumf;
3586
+ #endif
3587
+ }
3588
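Editor's note, not part of the package diff: a minimal sketch of the iq3_xxs fields as read by the scalar path above. Each byte of `q3` selects one iq3xxs_grid entry (4 weights); the 32-bit word read from `gas` packs four 7-bit sign-group indices plus a 4-bit scale. Helper names are illustrative.

```c
/* Editorial sketch, not package code: gas-word split used by the iq3_xxs
 * scalar path above. */
#include <stdint.h>

static inline uint8_t iq3xxs_sign_index(uint32_t gas_word, int l) {
    return (uint8_t)((gas_word >> 7*l) & 127); /* matches (aux32 >> 7*l) & 127 */
}
static inline uint8_t iq3xxs_block_scale(uint32_t gas_word) {
    return (uint8_t)(gas_word >> 28);          /* used as 2*scale + 1 above    */
}
```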
+
3589
+ void lm_ggml_vec_dot_iq3_s_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3590
+ assert(n % QK_K == 0);
3591
+ assert(nrc == 1);
3592
+ UNUSED(nrc);
3593
+ UNUSED(bx);
3594
+ UNUSED(by);
3595
+ UNUSED(bs);
3596
+
3597
+ const block_iq3_s * LM_GGML_RESTRICT x = vx;
3598
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3599
+
3600
+ const int nb = n / QK_K;
3601
+
3602
+ #if defined(__ARM_NEON)
3603
+
3604
+ typedef union {
3605
+ uint16x8_t vec_index;
3606
+ uint16_t index[8];
3607
+ } vec_index_t;
3608
+
3609
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3610
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3611
+ };
3612
+
3613
+ static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
3614
+
3615
+ static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
3616
+
3617
+ const lm_ggml_uint8x16x2_t mask1 = lm_ggml_vld1q_u8_x2(k_mask1);
3618
+ const uint8x16_t mask2 = vld1q_u8(k_mask2);
3619
+
3620
+ const int16x8_t hshift = vld1q_s16(k_shift);
3621
+ const uint16x8_t m256 = vdupq_n_u16(256);
3622
+ const uint8x16_t m1 = vdupq_n_u8(1);
3623
+
3624
+ uint8x16x2_t vs;
3625
+ lm_ggml_int8x16x4_t q3s;
3626
+ lm_ggml_int8x16x4_t q8b;
3627
+ vec_index_t idx;
3628
+
3629
+ uint32_t scales32[2];
3630
+ const uint8_t * scales8 = (const uint8_t *)scales32;
3631
+
3632
+ float sumf = 0;
3633
+ for (int i = 0; i < nb; ++i) {
3634
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3635
+ const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
3636
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3637
+ const uint16_t * LM_GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3638
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3639
+
3640
+ memcpy(scales32, x[i].scales, 4);
3641
+ scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
3642
+ scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
3643
+
3644
+ int sumi1 = 0, sumi2 = 0;
3645
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3646
+ q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
3647
+
3648
+ const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
3649
+ idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
3650
+ const uint32x4_t aux32x4_0 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
3651
+ iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
3652
+ const uint32x4_t aux32x4_1 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
3653
+ iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
3654
+ idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
3655
+ const uint32x4_t aux32x4_2 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
3656
+ iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
3657
+ const uint32x4_t aux32x4_3 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
3658
+ iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
3659
+
3660
+
3661
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
3662
+ vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3663
+ vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3664
+ vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
3665
+ vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
3666
+
3667
+ q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
3668
+ q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
3669
+
3670
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
3671
+ vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
3672
+ vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
3673
+ vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
3674
+ vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
3675
+
3676
+ signs += 4;
3677
+
3678
+ q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
3679
+ q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));
3680
+
3681
+ const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
3682
+ const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
3683
+
3684
+ sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
3685
+ sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
3686
+ }
3687
+ sumf += d*(sumi1 + sumi2);
3688
+ }
3689
+ *s = sumf;
3690
+
3691
+ #else
3692
+
3693
+ float sumf = 0.f;
3694
+ for (int i = 0; i < nb; ++i) {
3695
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3696
+ const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
3697
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3698
+ const uint8_t * LM_GGML_RESTRICT signs = x[i].signs;
3699
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3700
+ int32_t bsum = 0;
3701
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3702
+ const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
3703
+ const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
3704
+ int32_t sumi = 0;
3705
+ for (int l = 0; l < 4; ++l) {
3706
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
3707
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
3708
+ for (int j = 0; j < 4; ++j) {
3709
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
3710
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
3711
+ }
3712
+ q8 += 8;
3713
+ }
3714
+ qs += 8;
3715
+ signs += 4;
3716
+ bsum += sumi * ls1;
3717
+ sumi = 0;
3718
+ for (int l = 0; l < 4; ++l) {
3719
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
3720
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
3721
+ for (int j = 0; j < 4; ++j) {
3722
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
3723
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
3724
+ }
3725
+ q8 += 8;
3726
+ }
3727
+ qs += 8;
3728
+ signs += 4;
3729
+ bsum += sumi * ls2;
3730
+ }
3731
+ sumf += d * bsum;
3732
+ }
3733
+ *s = sumf;
3734
+ #endif
3735
+ }
3736
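Editor's note, not part of the package diff: a minimal sketch of the 9-bit grid index that the iq3_s scalar path above builds for each 4-weight group. Names are illustrative; `shift` stands for the `(8-2*l)` / `(7-2*l)` amounts used in the loop.

```c
/* Editorial sketch, not package code: index assembly from the iq3_s scalar
 * path above. */
#include <stdint.h>

static inline uint16_t iq3_s_grid_index(uint8_t qs_byte, uint8_t qh_byte, int shift) {
    /* 8 low bits from qs plus one high bit pulled out of qh, as in
     * qs[2*l+0] | ((qh[ib32] << (8-2*l)) & 256) in the kernel above */
    return (uint16_t)(qs_byte | (((uint16_t)qh_byte << shift) & 256));
}
```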
+
3737
+ void lm_ggml_vec_dot_iq1_s_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3738
+ assert(n % QK_K == 0);
3739
+ assert(nrc == 1);
3740
+ UNUSED(nrc);
3741
+ UNUSED(bx);
3742
+ UNUSED(by);
3743
+ UNUSED(bs);
3744
+
3745
+ const block_iq1_s * LM_GGML_RESTRICT x = vx;
3746
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3747
+
3748
+ const int nb = n / QK_K;
3749
+
3750
+ #if defined __ARM_NEON
3751
+
3752
+ lm_ggml_int8x16x4_t q1b;
3753
+ lm_ggml_int8x16x4_t q8b;
3754
+
3755
+ float sumf = 0;
3756
+ for (int i = 0; i < nb; ++i) {
3757
+
3758
+ const int8_t * q8 = y[i].qs;
3759
+ const uint8_t * qs = x[i].qs;
3760
+ const uint16_t * qh = x[i].qh;
3761
+
3762
+ int sumi1 = 0, sumi2 = 0, sumi3 = 0;
3763
+
3764
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
3765
+
3766
+ q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),
3767
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));
3768
+ q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),
3769
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));
3770
+ q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),
3771
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));
3772
+ q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),
3773
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700)))));
3774
+ qs += 8;
3775
+
3776
+ q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
3777
+
3778
+ const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
3779
+ const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);
3780
+
3781
+ const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
3782
+ const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
3783
+ sumi1 += vaddvq_s32(p1) * ls1;
3784
+ sumi2 += vaddvq_s32(p2) * ls2;
3785
+ sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1)
3786
+ + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1);
3787
+
3788
+ }
3789
+
3790
+ sumf += y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3);
3791
+ }
3792
+
3793
+ *s = sumf;
3794
+
3795
+ #else
3796
+
3797
+ float sumf = 0;
3798
+ for (int i = 0; i < nb; i++) {
3799
+
3800
+ const int8_t * q8 = y[i].qs;
3801
+ const uint8_t * qs = x[i].qs;
3802
+ const uint16_t * qh = x[i].qh;
3803
+
3804
+ int sumi = 0, sumi1 = 0;
3805
+ for (int ib = 0; ib < QK_K/32; ++ib) {
3806
+ const int ls = 2*((qh[ib] >> 12) & 7) + 1;
3807
+ const int delta = qh[ib] & 0x8000 ? -1 : 1;
3808
+ int lsum = 0;
3809
+ for (int l = 0; l < 4; ++l) {
3810
+ const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
3811
+ for (int j = 0; j < 8; ++j) {
3812
+ lsum += q8[j] * grid[j];
3813
+ }
3814
+ q8 += 8;
3815
+ }
3816
+ sumi += ls * lsum;
3817
+ sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
3818
+ qs += 4;
3819
+ }
3820
+
3821
+ sumf += LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
3822
+ }
3823
+
3824
+ *s = sumf;
3825
+
3826
+ #endif
3827
+ }
3828
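Editor's note, not part of the package diff: a minimal sketch of how the iq1_s scalar path above unpacks one 16-bit `qh` word per 32-weight sub-block: bits 0..11 hold four 3-bit high parts of the grid indices, bits 12..14 the scale, bit 15 the delta sign applied through IQ1S_DELTA. Struct and function names are illustrative.

```c
/* Editorial sketch, not package code: qh-word split from the iq1_s scalar
 * path above. */
#include <stdint.h>

typedef struct {
    uint16_t grid_idx[4]; /* qs[l] | (((qh >> 3*l) & 7) << 8): 11-bit iq1s_grid index */
    int      scale;       /* 2*((qh >> 12) & 7) + 1                                   */
    int      delta_sign;  /* qh & 0x8000 ? -1 : +1, weighted by IQ1S_DELTA            */
} iq1_s_fields;

static inline iq1_s_fields iq1_s_split(const uint8_t qs[4], uint16_t qh) {
    iq1_s_fields f;
    for (int l = 0; l < 4; ++l) {
        f.grid_idx[l] = (uint16_t)(qs[l] | (((qh >> 3*l) & 7) << 8));
    }
    f.scale      = 2*((qh >> 12) & 7) + 1;
    f.delta_sign = (qh & 0x8000) ? -1 : 1;
    return f;
}
```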
+
3829
+ void lm_ggml_vec_dot_iq1_m_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3830
+ assert(n % QK_K == 0);
3831
+ assert(nrc == 1);
3832
+ UNUSED(nrc);
3833
+ UNUSED(bx);
3834
+ UNUSED(by);
3835
+ UNUSED(bs);
3836
+
3837
+ const block_iq1_m * LM_GGML_RESTRICT x = vx;
3838
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3839
+
3840
+ const int nb = n / QK_K;
3841
+
3842
+ iq1m_scale_t scale;
3843
+
3844
+ #if defined __ARM_NEON
3845
+ const int32x4_t mask = vdupq_n_s32(0x7);
3846
+ const int32x4_t mone = vdupq_n_s32(1);
3847
+ const int32x4_t mzero = vdupq_n_s32(0);
3848
+
3849
+ lm_ggml_int8x16x4_t deltas;
3850
+ deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
3851
+ deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
3852
+ deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
3853
+ deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));
3854
+
3855
+ lm_ggml_int8x16x4_t q1b;
3856
+ lm_ggml_int8x16x4_t q8b;
3857
+
3858
+ uint32_t aux32;
3859
+ const uint8_t * aux8 = (const uint8_t *)&aux32;
3860
+
3861
+ float sumf = 0;
3862
+ for (int i = 0; i < nb; ++i) {
3863
+
3864
+ const int8_t * q8 = y[i].qs;
3865
+ const uint8_t * qs = x[i].qs;
3866
+ const uint8_t * qh = x[i].qh;
3867
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
3868
+
3869
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
3870
+
3871
+ int32x4_t sumi1 = mzero;
3872
+ int32x4_t sumi2 = mzero;
3873
+
3874
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
3875
+
3876
+ q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
3877
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
3878
+ q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
3879
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
3880
+ q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
3881
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
3882
+ q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
3883
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));
3884
+
3885
+ q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
3886
+
3887
+ const int32x4_t p1 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), lm_ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
3888
+ const int32x4_t p2 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), lm_ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
3889
+ const int32x4_t p12 = vpaddq_s32(p1, p2);
3890
+
3891
+ const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that
3892
+ aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202);
3893
+
3894
+ const int32x4_t p3 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), lm_ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));
3895
+ const int32x4_t p4 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), lm_ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
3896
+ const int32x4_t p34 = vpaddq_s32(p3, p4);
3897
+
3898
+ int32x4_t scales_4 = lm_ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
3899
+
3900
+ scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
3901
+
3902
+ sumi1 = vmlaq_s32(sumi1, scales_4, p12);
3903
+ sumi2 = vmlaq_s32(sumi2, scales_4, p34);
3904
+
3905
+ qs += 8; qh += 4;
3906
+
3907
+ }
3908
+
3909
+ sumf += y[i].d * LM_GGML_CPU_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
3910
+ }
3911
+
3912
+ *s = sumf;
3913
+
3914
+ #else
3915
+
3916
+ int sum1[2], sum2[2], delta[4];
3917
+
3918
+ float sumf = 0;
3919
+ for (int i = 0; i < nb; i++) {
3920
+
3921
+ const int8_t * q8 = y[i].qs;
3922
+ const uint8_t * qs = x[i].qs;
3923
+ const uint8_t * qh = x[i].qh;
3924
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
3925
+
3926
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
3927
+
3928
+ int sumi1 = 0, sumi2 = 0;
3929
+ for (int ib = 0; ib < QK_K/32; ++ib) {
3930
+ delta[0] = qh[0] & 0x08 ? -1 : 1;
3931
+ delta[1] = qh[0] & 0x80 ? -1 : 1;
3932
+ delta[2] = qh[1] & 0x08 ? -1 : 1;
3933
+ delta[3] = qh[1] & 0x80 ? -1 : 1;
3934
+ sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
3935
+ for (int l = 0; l < 4; ++l) {
3936
+ const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
3937
+ int lsum1 = 0, lsum2 = 0;
3938
+ for (int j = 0; j < 8; ++j) {
3939
+ lsum1 += q8[j] * grid[j];
3940
+ lsum2 += q8[j];
3941
+ }
3942
+ q8 += 8;
3943
+ sum1[l/2] += lsum1;
3944
+ sum2[l/2] += lsum2*delta[l];
3945
+ }
3946
+
3947
+ const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
3948
+ const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
3949
+
3950
+ sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
3951
+ sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
3952
+ qs += 4;
3953
+ qh += 2;
3954
+ }
3955
+
3956
+ sumf += LM_GGML_CPU_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
3957
+ }
3958
+
3959
+ *s = sumf;
3960
+
3961
+ #endif
3962
+ }
3963
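Editor's note, not part of the package diff: a minimal sketch of how the iq1_m kernel above reassembles the block's FP16 super-scale from the top nibble of each of the four 16-bit words in `x[i].scales` (the low 12 bits of each word carry the 3-bit sub-scales used later in the loop). The helper name is illustrative.

```c
/* Editorial sketch, not package code: super-scale bit assembly from the
 * iq1_m kernel above. */
#include <stdint.h>

static inline uint16_t iq1_m_superscale_bits(const uint16_t sc[4]) {
    return (uint16_t)((sc[0] >> 12)            |
                      ((sc[1] >>  8) & 0x00f0) |
                      ((sc[2] >>  4) & 0x0f00) |
                      ( sc[3]        & 0xf000));
}
```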
+
3964
+ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3965
+ assert(nrc == 1);
3966
+ UNUSED(nrc);
3967
+ UNUSED(bx);
3968
+ UNUSED(by);
3969
+ UNUSED(bs);
3970
+ assert(n % QK4_NL == 0);
3971
+ static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
3972
+
3973
+ const block_iq4_nl * LM_GGML_RESTRICT x = vx;
3974
+ const block_q8_0 * LM_GGML_RESTRICT y = vy;
3975
+
3976
+ const int nb = n / QK4_NL;
3977
+
3978
+ int ib = 0;
3979
+ float sumf = 0;
3980
+
3981
+ #if defined __ARM_NEON
3982
+ const int8x16_t values = vld1q_s8(kvalues_iq4nl);
3983
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
3984
+ uint8x16x2_t q4bits;
3985
+ int8x16x4_t q4b;
3986
+ int8x16x4_t q8b;
3987
+ int32x4_t prod_1, prod_2;
3988
+
3989
+ for (; ib + 1 < nb; ib += 2) {
3990
+
3991
+ q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
3992
+ q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
3993
+ q8b.val[0] = vld1q_s8(y[ib + 0].qs);
3994
+ q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16);
3995
+ q8b.val[2] = vld1q_s8(y[ib + 1].qs);
3996
+ q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);
3997
+
3998
+ q4b.val[0] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
3999
+ q4b.val[1] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
4000
+ q4b.val[2] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
4001
+ q4b.val[3] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
4002
+
4003
+ prod_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
4004
+ prod_2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
4005
+
4006
+ sumf +=
4007
+ LM_GGML_CPU_FP16_TO_FP32(x[ib+0].d) * LM_GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
4008
+ LM_GGML_CPU_FP16_TO_FP32(x[ib+1].d) * LM_GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
4009
+ }
4010
+
4011
+ #endif
4012
+ for (; ib < nb; ++ib) {
4013
+ const float d = LM_GGML_CPU_FP16_TO_FP32(y[ib].d)*LM_GGML_CPU_FP16_TO_FP32(x[ib].d);
4014
+ int sumi1 = 0, sumi2 = 0;
4015
+ for (int j = 0; j < QK4_NL/2; ++j) {
4016
+ sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
4017
+ sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
4018
+ }
4019
+ sumf += d * (sumi1 + sumi2);
4020
+ }
4021
+ *s = sumf;
4022
+ }
4023
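Editor's note, not part of the package diff: a minimal sketch of the per-byte decode performed by the iq4_nl scalar tail above; `lut` stands for the kvalues_iq4nl table, and the function name is illustrative.

```c
/* Editorial sketch, not package code: per-byte decode from the iq4_nl scalar
 * tail above. */
#include <stdint.h>

static inline void iq4_nl_decode_byte(uint8_t b, const int8_t lut[16],
                                      int8_t *w_lo, int8_t *w_hi) {
    *w_lo = lut[b & 0xf]; /* pairs with q8[j]           in the loop above */
    *w_hi = lut[b >> 4];  /* pairs with q8[j+QK4_NL/2]  in the loop above */
}
```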
+
4024
+ void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
4025
+ assert(nrc == 1);
4026
+ UNUSED(nrc);
4027
+ UNUSED(bx);
4028
+ UNUSED(by);
4029
+ UNUSED(bs);
4030
+ assert(n % QK_K == 0);
4031
+
4032
+ const block_iq4_xs * LM_GGML_RESTRICT x = vx;
4033
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
4034
+
4035
+ const int nb = n / QK_K;
4036
+
4037
+ #if defined __ARM_NEON
4038
+ const int8x16_t values = vld1q_s8(kvalues_iq4nl);
4039
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
4040
+ lm_ggml_uint8x16x2_t q4bits;
4041
+ lm_ggml_int8x16x4_t q4b;
4042
+ lm_ggml_int8x16x4_t q8b;
4043
+ int32x4_t prod_1, prod_2;
4044
+
4045
+ float sumf = 0;
4046
+
4047
+ for (int ibl = 0; ibl < nb; ++ibl) {
4048
+
4049
+ const int8_t * q8 = y[ibl].qs;
4050
+ const uint8_t * q4 = x[ibl].qs;
4051
+ uint16_t h = x[ibl].scales_h;
4052
+
4053
+ int sumi1 = 0, sumi2 = 0;
4054
+ for (int ib = 0; ib < QK_K/64; ++ib) {
4055
+
4056
+ q4bits = lm_ggml_vld1q_u8_x2(q4); q4 += 32;
4057
+ q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
4058
+
4059
+ q4b.val[0] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
4060
+ q4b.val[1] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
4061
+ q4b.val[2] = lm_ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
4062
+ q4b.val[3] = lm_ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
4063
+
4064
+ prod_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
4065
+ prod_2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
4066
+
4067
+ int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
4068
+ int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
4069
+ h >>= 4;
4070
+ sumi1 += vaddvq_s32(prod_1) * ls1;
4071
+ sumi2 += vaddvq_s32(prod_2) * ls2;
4072
+
4073
+ }
4074
+
4075
+ sumf += LM_GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
4076
+ }
4077
+
4078
+ *s = sumf;
4079
+
4080
+ #else
4081
+ float sumf = 0;
4082
+ for (int ibl = 0; ibl < nb; ++ibl) {
4083
+ const float d4d8 = LM_GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
4084
+ uint16_t h = x[ibl].scales_h;
4085
+ const uint8_t * qs = x[ibl].qs;
4086
+ const int8_t * q8 = y[ibl].qs;
4087
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
4088
+ const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
4089
+ const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30);
4090
+ h >>= 4;
4091
+ const float d1 = d4d8*(ls1 - 32);
4092
+ const float d2 = d4d8*(ls2 - 32);
4093
+ int sumi1 = 0, sumi2 = 0;
4094
+ for (int j = 0; j < 16; ++j) {
4095
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
4096
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
4097
+ }
4098
+ sumf += d1 * (sumi1 + sumi2);
4099
+ qs += 16;
4100
+ q8 += 32;
4101
+ sumi1 = sumi2 = 0;
4102
+ for (int j = 0; j < 16; ++j) {
4103
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
4104
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
4105
+ }
4106
+ sumf += d2 * (sumi1 + sumi2);
4107
+ qs += 16;
4108
+ q8 += 32;
4109
+ }
4110
+ }
4111
+ *s = sumf;
4112
+ #endif
4113
+ }
4114
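Editor's note, not part of the package diff: a minimal sketch of the 6-bit scale assembly used by the iq4_xs scalar path above for one pair of 32-weight sub-blocks; `h` is the running copy of `x[ibl].scales_h`, exactly as in the loop, and the helper name is illustrative.

```c
/* Editorial sketch, not package code: scale assembly from the iq4_xs scalar
 * path above. */
#include <stdint.h>

static inline void iq4_xs_scales(uint8_t scales_l_byte, uint16_t h,
                                 int *ls1, int *ls2) {
    *ls1 = (int)((scales_l_byte & 0xf) | ((h << 4) & 0x30)) - 32;
    *ls2 = (int)((scales_l_byte >> 4)  | ((h << 2) & 0x30)) - 32;
    /* the caller then shifts h right by 4 before the next pair, as above */
}
```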
+