cui-llama.rn 1.7.4 → 1.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276) hide show
  1. package/README.md +217 -17
  2. package/android/src/main/CMakeLists.txt +34 -15
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +79 -5
  4. package/android/src/main/java/com/rnllama/RNLlama.java +237 -0
  5. package/android/src/main/jni.cpp +213 -14
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
  16. package/cpp/README.md +1 -1
  17. package/cpp/chat-parser.cpp +385 -0
  18. package/cpp/chat-parser.h +120 -0
  19. package/cpp/chat.cpp +726 -596
  20. package/cpp/chat.h +71 -6
  21. package/cpp/common.cpp +56 -38
  22. package/cpp/common.h +9 -3
  23. package/cpp/ggml-backend-reg.cpp +5 -0
  24. package/cpp/ggml-backend.cpp +10 -2
  25. package/cpp/ggml-common.h +4 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +1 -1
  27. package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
  28. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  29. package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
  30. package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
  31. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  32. package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
  33. package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  34. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  35. package/cpp/ggml-cpu/common.h +4 -3
  36. package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
  37. package/cpp/ggml-cpu/ggml-cpu.c +123 -104
  38. package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
  39. package/cpp/ggml-cpu/ops.cpp +330 -148
  40. package/cpp/ggml-cpu/ops.h +1 -0
  41. package/cpp/ggml-cpu/quants.c +1158 -0
  42. package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  43. package/cpp/ggml-cpu/repack.cpp +1571 -0
  44. package/cpp/ggml-cpu/repack.h +98 -0
  45. package/cpp/ggml-cpu/simd-mappings.h +330 -38
  46. package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  47. package/cpp/ggml-cpu/vec.cpp +87 -18
  48. package/cpp/ggml-cpu/vec.h +249 -94
  49. package/cpp/ggml-cpu.h +1 -0
  50. package/cpp/ggml-impl.h +63 -183
  51. package/cpp/ggml-llama-sim.metallib +0 -0
  52. package/cpp/ggml-llama.metallib +0 -0
  53. package/cpp/ggml-metal.m +152 -45
  54. package/cpp/ggml-quants.c +0 -2
  55. package/cpp/ggml.c +61 -21
  56. package/cpp/ggml.h +22 -3
  57. package/cpp/gguf.cpp +24 -3
  58. package/cpp/json-partial.cpp +256 -0
  59. package/cpp/json-partial.h +38 -0
  60. package/cpp/json-schema-to-grammar.cpp +5 -47
  61. package/cpp/json-schema-to-grammar.h +4 -4
  62. package/cpp/llama-arch.cpp +153 -3
  63. package/cpp/llama-arch.h +27 -1
  64. package/cpp/llama-batch.cpp +741 -272
  65. package/cpp/llama-batch.h +112 -54
  66. package/cpp/llama-chat.cpp +30 -8
  67. package/cpp/llama-chat.h +1 -0
  68. package/cpp/llama-context.cpp +524 -339
  69. package/cpp/llama-context.h +38 -17
  70. package/cpp/llama-cparams.cpp +4 -0
  71. package/cpp/llama-cparams.h +2 -0
  72. package/cpp/llama-grammar.cpp +12 -2
  73. package/cpp/llama-graph.cpp +431 -356
  74. package/cpp/llama-graph.h +126 -58
  75. package/cpp/llama-hparams.cpp +10 -2
  76. package/cpp/llama-hparams.h +19 -2
  77. package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
  78. package/cpp/llama-kv-cache-unified-iswa.h +128 -0
  79. package/cpp/llama-kv-cache-unified.cpp +1841 -0
  80. package/cpp/llama-kv-cache-unified.h +303 -0
  81. package/cpp/llama-kv-cells.h +439 -0
  82. package/cpp/llama-memory-hybrid.cpp +246 -0
  83. package/cpp/llama-memory-hybrid.h +138 -0
  84. package/cpp/llama-memory-recurrent.cpp +1112 -0
  85. package/cpp/llama-memory-recurrent.h +183 -0
  86. package/cpp/llama-memory.cpp +41 -0
  87. package/cpp/llama-memory.h +86 -5
  88. package/cpp/llama-mmap.cpp +1 -1
  89. package/cpp/llama-model-loader.cpp +42 -17
  90. package/cpp/llama-model-saver.cpp +1 -0
  91. package/cpp/llama-model.cpp +1639 -513
  92. package/cpp/llama-model.h +26 -0
  93. package/cpp/llama-sampling.cpp +2 -2
  94. package/cpp/llama-vocab.cpp +65 -28
  95. package/cpp/llama-vocab.h +1 -0
  96. package/cpp/llama.cpp +11 -7
  97. package/cpp/llama.h +150 -42
  98. package/cpp/minja/chat-template.hpp +1 -1
  99. package/cpp/minja/minja.hpp +1 -1
  100. package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
  101. package/cpp/nlohmann/json_fwd.hpp +187 -0
  102. package/cpp/regex-partial.cpp +204 -0
  103. package/cpp/regex-partial.h +56 -0
  104. package/cpp/rn-llama.cpp +646 -35
  105. package/cpp/rn-llama.h +32 -1
  106. package/cpp/rn-tts.h +39 -0
  107. package/cpp/sampling.cpp +7 -8
  108. package/cpp/tools/mtmd/clip-impl.h +5 -0
  109. package/cpp/tools/mtmd/clip.cpp +572 -436
  110. package/cpp/tools/mtmd/clip.h +14 -4
  111. package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
  112. package/cpp/tools/mtmd/mtmd-audio.h +2 -17
  113. package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
  114. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  115. package/cpp/tools/mtmd/mtmd.cpp +368 -248
  116. package/cpp/tools/mtmd/mtmd.h +6 -70
  117. package/cpp/unicode.cpp +5 -0
  118. package/ios/CMakeLists.txt +26 -6
  119. package/ios/RNLlama.h +1 -1
  120. package/ios/RNLlama.mm +153 -3
  121. package/ios/RNLlamaContext.h +9 -1
  122. package/ios/RNLlamaContext.mm +112 -9
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  143. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  144. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  184. package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  218. package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  225. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  226. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  227. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  228. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  229. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  230. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  231. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  232. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  233. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  234. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  235. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  259. package/jest/mock.js +24 -0
  260. package/package.json +1 -1
  261. package/src/NativeRNLlama.ts +46 -2
  262. package/src/index.ts +105 -1
  263. package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  264. package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
  265. package/cpp/ggml-cpu/sgemm.cpp +0 -3544
  266. package/cpp/ggml-cpu/sgemm.h +0 -14
  267. package/cpp/llama-kv-cache.cpp +0 -2827
  268. package/cpp/llama-kv-cache.h +0 -515
  269. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  270. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  274. /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  275. /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
  276. /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
@@ -0,0 +1,4311 @@
1
+ #define LM_GGML_COMMON_IMPL_C
2
+ #include "ggml-common.h"
3
+ #include "ggml-quants.h"
4
+ #include "ggml-impl.h"
5
+ #include "ggml-cpu.h"
6
+ #include "simd-mappings.h"
7
+
8
+ #include "../../quants.h"
9
+ #include "../../ggml-cpu-impl.h"
10
+
11
+ #include <math.h>
12
+ #include <string.h>
13
+ #include <assert.h>
14
+ #include <stdlib.h> // for qsort
15
+ #include <stdio.h> // for LM_GGML_ASSERT
16
+
17
+ #define GROUP_MAX_EPS 1e-15f
18
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
19
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
20
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
21
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
22
+
23
+ #define UNUSED LM_GGML_UNUSED
24
+
25
+ // some compilers don't provide _mm256_set_m128i, e.g. gcc 7
26
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
27
+
28
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
29
+ // multiply int8_t, add results pairwise twice
30
+ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
31
+ // Get absolute values of x vectors
32
+ const __m128i ax = _mm_sign_epi8(x, x);
33
+ // Sign the values of the y vectors
34
+ const __m128i sy = _mm_sign_epi8(y, x);
35
+ // Perform multiplication and create 16-bit values
36
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
37
+ const __m128i ones = _mm_set1_epi16(1);
38
+ return _mm_madd_epi16(ones, dot);
39
+ }
40
+
41
+ #if __AVX__ || __AVX2__ || __AVX512F__
42
+ // horizontally add 8 floats
43
+ static inline float hsum_float_8(const __m256 x) {
44
+ __m128 res = _mm256_extractf128_ps(x, 1);
45
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
46
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
47
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
48
+ return _mm_cvtss_f32(res);
49
+ }
50
+
51
+ // horizontally add 8 int32_t
52
+ static inline int hsum_i32_8(const __m256i a) {
53
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
54
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
55
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
56
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
57
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
58
+ }
59
+
60
+ // horizontally add 4 int32_t
61
+ static inline int hsum_i32_4(const __m128i a) {
62
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
63
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
64
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
65
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
66
+ }
67
+
68
+ #if defined(__AVX2__) || defined(__AVX512F__)
69
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
70
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
71
+ uint32_t x32;
72
+ memcpy(&x32, x, sizeof(uint32_t));
73
+ const __m256i shuf_mask = _mm256_set_epi64x(
74
+ 0x0303030303030303, 0x0202020202020202,
75
+ 0x0101010101010101, 0x0000000000000000);
76
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
77
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
78
+ bytes = _mm256_or_si256(bytes, bit_mask);
79
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
80
+ }
81
+
82
+ // Unpack 32 4-bit fields into 32 bytes
83
+ // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
84
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
85
+ {
86
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
87
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
88
+ const __m256i lowMask = _mm256_set1_epi8( 0xF );
89
+ return _mm256_and_si256(lowMask, bytes);
90
+ }
91
+
92
+ // add int16_t pairwise and return as float vector
93
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
94
+ const __m256i ones = _mm256_set1_epi16(1);
95
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
96
+ return _mm256_cvtepi32_ps(summed_pairs);
97
+ }
98
+
99
+ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
100
+ #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
101
+ const __m256i zero = _mm256_setzero_si256();
102
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
103
+ return _mm256_cvtepi32_ps(summed_pairs);
104
+ #elif defined(__AVXVNNI__)
105
+ const __m256i zero = _mm256_setzero_si256();
106
+ const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
107
+ return _mm256_cvtepi32_ps(summed_pairs);
108
+ #else
109
+ // Perform multiplication and create 16-bit values
110
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
111
+ return sum_i16_pairs_float(dot);
112
+ #endif
113
+ }
114
+
115
+ // multiply int8_t, add results pairwise twice and return as float vector
116
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
117
+ #if __AVXVNNIINT8__
118
+ const __m256i zero = _mm256_setzero_si256();
119
+ const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
120
+ return _mm256_cvtepi32_ps(summed_pairs);
121
+ #else
122
+ // Get absolute values of x vectors
123
+ const __m256i ax = _mm256_sign_epi8(x, x);
124
+ // Sign the values of the y vectors
125
+ const __m256i sy = _mm256_sign_epi8(y, x);
126
+ return mul_sum_us8_pairs_float(ax, sy);
127
+ #endif
128
+ }
129
+
130
+ static inline __m128i packNibbles( __m256i bytes )
131
+ {
132
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
133
+ #if __AVX512F__
134
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
135
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
136
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
137
+ #else
138
+ const __m256i lowByte = _mm256_set1_epi16( 0xFF );
139
+ __m256i high = _mm256_andnot_si256( lowByte, bytes );
140
+ __m256i low = _mm256_and_si256( lowByte, bytes );
141
+ high = _mm256_srli_epi16( high, 4 );
142
+ bytes = _mm256_or_si256( low, high );
143
+
144
+ // Compress uint16_t lanes into bytes
145
+ __m128i r0 = _mm256_castsi256_si128( bytes );
146
+ __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
147
+ return _mm_packus_epi16( r0, r1 );
148
+ #endif
149
+ }
150
+ #elif defined(__AVX__)
151
+ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
152
+ {
153
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
154
+ const __m128i lowByte = _mm_set1_epi16( 0xFF );
155
+ __m128i high = _mm_andnot_si128( lowByte, bytes1 );
156
+ __m128i low = _mm_and_si128( lowByte, bytes1 );
157
+ high = _mm_srli_epi16( high, 4 );
158
+ bytes1 = _mm_or_si128( low, high );
159
+ high = _mm_andnot_si128( lowByte, bytes2 );
160
+ low = _mm_and_si128( lowByte, bytes2 );
161
+ high = _mm_srli_epi16( high, 4 );
162
+ bytes2 = _mm_or_si128( low, high );
163
+
164
+ return _mm_packus_epi16( bytes1, bytes2);
165
+ }
166
+
167
+ static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
168
+ const __m128i ax = _mm_sign_epi8(x, x);
169
+ const __m128i sy = _mm_sign_epi8(y, x);
170
+ return _mm_maddubs_epi16(ax, sy);
171
+ }
172
+
173
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
174
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
175
+ uint32_t x32;
176
+ memcpy(&x32, x, sizeof(uint32_t));
177
+ const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
178
+ const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
179
+ __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
180
+ __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
181
+ const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
182
+ bytesl = _mm_or_si128(bytesl, bit_mask);
183
+ bytesh = _mm_or_si128(bytesh, bit_mask);
184
+ bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
185
+ bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
186
+ return MM256_SET_M128I(bytesh, bytesl);
187
+ }
188
+
189
+ // Unpack 32 4-bit fields into 32 bytes
190
+ // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
191
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
192
+ {
193
+ // Load 16 bytes from memory
194
+ __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
195
+ __m128i tmph = _mm_srli_epi16(tmpl, 4);
196
+ const __m128i lowMask = _mm_set1_epi8(0xF);
197
+ tmpl = _mm_and_si128(lowMask, tmpl);
198
+ tmph = _mm_and_si128(lowMask, tmph);
199
+ return MM256_SET_M128I(tmph, tmpl);
200
+ }
201
+
202
+ // add int16_t pairwise and return as float vector
203
+ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
204
+ const __m128i ones = _mm_set1_epi16(1);
205
+ const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
206
+ const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
207
+ const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
208
+ return _mm256_cvtepi32_ps(summed_pairs);
209
+ }
210
+
211
+ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
212
+ const __m128i axl = _mm256_castsi256_si128(ax);
213
+ const __m128i axh = _mm256_extractf128_si256(ax, 1);
214
+ const __m128i syl = _mm256_castsi256_si128(sy);
215
+ const __m128i syh = _mm256_extractf128_si256(sy, 1);
216
+ // Perform multiplication and create 16-bit values
217
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
218
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
219
+ return sum_i16_pairs_float(doth, dotl);
220
+ }
221
+
222
+ // multiply int8_t, add results pairwise twice and return as float vector
223
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
224
+ const __m128i xl = _mm256_castsi256_si128(x);
225
+ const __m128i xh = _mm256_extractf128_si256(x, 1);
226
+ const __m128i yl = _mm256_castsi256_si128(y);
227
+ const __m128i yh = _mm256_extractf128_si256(y, 1);
228
+ // Get absolute values of x vectors
229
+ const __m128i axl = _mm_sign_epi8(xl, xl);
230
+ const __m128i axh = _mm_sign_epi8(xh, xh);
231
+ // Sign the values of the y vectors
232
+ const __m128i syl = _mm_sign_epi8(yl, xl);
233
+ const __m128i syh = _mm_sign_epi8(yh, xh);
234
+ // Perform multiplication and create 16-bit values
235
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
236
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
237
+ return sum_i16_pairs_float(doth, dotl);
238
+ }
239
+
240
+ // larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
241
+ static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
242
+ const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
243
+ const __m128i mone = _mm_set1_epi16(1);
244
+
245
+ const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
246
+ const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
247
+ const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
248
+ const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
249
+ const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
250
+ const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
251
+ const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
252
+ const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
253
+ const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);
254
+ const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);
255
+ return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));
256
+ }
257
+
258
+ // quad fp16 delta calculation
259
+ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {
260
+ // LM_GGML_CPU_FP16_TO_FP32 is faster than Intel F16C
261
+ return _mm256_set_m128(_mm_set1_ps(LM_GGML_CPU_FP16_TO_FP32(x1) * LM_GGML_CPU_FP16_TO_FP32(y1)),
262
+ _mm_set1_ps(LM_GGML_CPU_FP16_TO_FP32(x0) * LM_GGML_CPU_FP16_TO_FP32(y0)));
263
+ }
264
+ #endif
265
+ #elif defined(__SSSE3__)
266
+ // horizontally add 4x4 floats
267
+ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
268
+ __m128 res_0 =_mm_hadd_ps(a, b);
269
+ __m128 res_1 =_mm_hadd_ps(c, d);
270
+ __m128 res =_mm_hadd_ps(res_0, res_1);
271
+ res =_mm_hadd_ps(res, res);
272
+ res =_mm_hadd_ps(res, res);
273
+
274
+ return _mm_cvtss_f32(res);
275
+ }
276
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
277
+ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
278
+
279
+ void quantize_row_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
280
+ assert(QK8_0 == 32);
281
+ assert(k % QK8_0 == 0);
282
+ const int nb = k / QK8_0;
283
+
284
+ block_q8_0 * LM_GGML_RESTRICT y = vy;
285
+
286
+ #if defined(__AVX2__) || defined(__AVX__)
287
+ for (int i = 0; i < nb; i++) {
288
+ // Load elements into 4 AVX vectors
289
+ __m256 v0 = _mm256_loadu_ps( x );
290
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
291
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
292
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
293
+ x += 32;
294
+
295
+ // Compute max(abs(e)) for the block
296
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
297
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
298
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
299
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
300
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
301
+
302
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
303
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
304
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
305
+ const float maxScalar = _mm_cvtss_f32( max4 );
306
+
307
+ // Quantize these floats
308
+ const float d = maxScalar / 127.f;
309
+ y[i].d = LM_GGML_CPU_FP32_TO_FP16(d);
310
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
311
+ const __m256 mul = _mm256_set1_ps( id );
312
+
313
+ // Apply the multiplier
314
+ v0 = _mm256_mul_ps( v0, mul );
315
+ v1 = _mm256_mul_ps( v1, mul );
316
+ v2 = _mm256_mul_ps( v2, mul );
317
+ v3 = _mm256_mul_ps( v3, mul );
318
+
319
+ // Round to nearest integer
320
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
321
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
322
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
323
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
324
+
325
+ // Convert floats to integers
326
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
327
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
328
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
329
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
330
+
331
+ #if defined(__AVX2__)
332
+ // Convert int32 to int16
333
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
334
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
335
+ // Convert int16 to int8
336
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
337
+
338
+ // We got our precious signed bytes, but the order is now wrong
339
+ // These AVX2 pack instructions process 16-byte pieces independently
340
+ // The following instruction is fixing the order
341
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
342
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
343
+
344
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
345
+ #else
346
+ // Since we don't have in AVX some necessary functions,
347
+ // we split the registers in half and call AVX2 analogs from SSE
348
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
349
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
350
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
351
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
352
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
353
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
354
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
355
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
356
+
357
+ // Convert int32 to int16
358
+ ni0 = _mm_packs_epi32( ni0, ni1 );
359
+ ni2 = _mm_packs_epi32( ni2, ni3 );
360
+ ni4 = _mm_packs_epi32( ni4, ni5 );
361
+ ni6 = _mm_packs_epi32( ni6, ni7 );
362
+ // Convert int16 to int8
363
+ ni0 = _mm_packs_epi16( ni0, ni2 );
364
+ ni4 = _mm_packs_epi16( ni4, ni6 );
365
+
366
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
367
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
368
+ #endif
369
+ }
370
+ #else
371
+ LM_GGML_UNUSED(nb);
372
+ // scalar
373
+ quantize_row_q8_0_ref(x, y, k);
374
+ #endif
375
+ }
376
+
377
+ void quantize_row_q8_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
378
+ assert(k % QK8_1 == 0);
379
+ const int nb = k / QK8_1;
380
+
381
+ block_q8_1 * LM_GGML_RESTRICT y = vy;
382
+ #if defined(__AVX2__) || defined(__AVX__)
383
+ for (int i = 0; i < nb; i++) {
384
+ // Load elements into 4 AVX vectors
385
+ __m256 v0 = _mm256_loadu_ps( x );
386
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
387
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
388
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
389
+ x += 32;
390
+
391
+ // Compute max(abs(e)) for the block
392
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
393
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
394
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
395
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
396
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
397
+
398
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
399
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
400
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
401
+ const float max_scalar = _mm_cvtss_f32( max4 );
402
+
403
+ // Quantize these floats
404
+ const float d = max_scalar / 127.f;
405
+ y[i].d = LM_GGML_CPU_FP32_TO_FP16(d);
406
+ const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
407
+ const __m256 mul = _mm256_set1_ps( id );
408
+
409
+ // Apply the multiplier
410
+ v0 = _mm256_mul_ps( v0, mul );
411
+ v1 = _mm256_mul_ps( v1, mul );
412
+ v2 = _mm256_mul_ps( v2, mul );
413
+ v3 = _mm256_mul_ps( v3, mul );
414
+
415
+ // Round to nearest integer
416
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
417
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
418
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
419
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
420
+
421
+ // Convert floats to integers
422
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
423
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
424
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
425
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
426
+
427
+ #if defined(__AVX2__)
428
+ // Compute the sum of the quants and set y[i].s
429
+ y[i].s = LM_GGML_CPU_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
430
+
431
+ // Convert int32 to int16
432
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
433
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
434
+ // Convert int16 to int8
435
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
436
+
437
+ // We got our precious signed bytes, but the order is now wrong
438
+ // These AVX2 pack instructions process 16-byte pieces independently
439
+ // The following instruction is fixing the order
440
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
441
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
442
+
443
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
444
+ #else
445
+ // Since we don't have in AVX some necessary functions,
446
+ // we split the registers in half and call AVX2 analogs from SSE
447
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
448
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
449
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
450
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
451
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
452
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
453
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
454
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
455
+
456
+ // Compute the sum of the quants and set y[i].s
457
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
458
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
459
+ y[i].s = LM_GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1)));
460
+
461
+ // Convert int32 to int16
462
+ ni0 = _mm_packs_epi32( ni0, ni1 );
463
+ ni2 = _mm_packs_epi32( ni2, ni3 );
464
+ ni4 = _mm_packs_epi32( ni4, ni5 );
465
+ ni6 = _mm_packs_epi32( ni6, ni7 );
466
+ // Convert int16 to int8
467
+ ni0 = _mm_packs_epi16( ni0, ni2 );
468
+ ni4 = _mm_packs_epi16( ni4, ni6 );
469
+
470
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
471
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
472
+ #endif
473
+ }
474
+ #else
475
+ LM_GGML_UNUSED(nb);
476
+ // scalar
477
+ quantize_row_q8_1_ref(x, y, k);
478
+ #endif
479
+ }
480
+
481
+ // placeholder implementation for Apple targets
482
+ void quantize_row_q8_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k) {
483
+ quantize_row_q8_K_ref(x, y, k);
484
+ }
485
+
486
+ //===================================== Dot products =================================
487
+
488
+ //
489
+ // Helper functions
490
+ //
491
+
492
+ #if __AVX__ || __AVX2__ || __AVX512F__
493
+
494
+ // shuffles to pick the required scales in dot products
495
+ static inline __m256i get_scale_shuffle_q3k(int i) {
496
+ static const uint8_t k_shuffle[128] = {
497
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
498
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
499
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
500
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
501
+ };
502
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
503
+ }
504
+ static inline __m256i get_scale_shuffle_k4(int i) {
505
+ static const uint8_t k_shuffle[256] = {
506
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
507
+ 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
508
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
509
+ 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
510
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
511
+ 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
512
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
513
+ 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
514
+ };
515
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
516
+ }
517
+ static inline __m128i get_scale_shuffle(int i) {
518
+ static const uint8_t k_shuffle[128] = {
519
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
520
+ 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
521
+ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
522
+ 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
523
+ 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
524
+ 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
525
+ 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
526
+ 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
527
+ };
528
+ return _mm_loadu_si128((const __m128i*)k_shuffle + i);
529
+ }
530
+ #endif
531
+
532
+ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
533
+ const int qk = QK8_0;
534
+ const int nb = n / qk;
535
+
536
+ assert(n % qk == 0);
537
+ assert(nrc == 1);
538
+ UNUSED(nrc);
539
+ UNUSED(bx);
540
+ UNUSED(by);
541
+ UNUSED(bs);
542
+
543
+ const block_q4_0 * LM_GGML_RESTRICT x = vx;
544
+ const block_q8_0 * LM_GGML_RESTRICT y = vy;
545
+
546
+ int ib = 0;
547
+ float sumf = 0;
548
+
549
+ #if defined(__AVX2__)
550
+ // Initialize accumulator with zeros
551
+ __m256 acc = _mm256_setzero_ps();
552
+
553
+ // Main loop
554
+ for (; ib < nb; ++ib) {
555
+ /* Compute combined scale for the block */
556
+ const __m256 d = _mm256_set1_ps( LM_GGML_CPU_FP16_TO_FP32(x[ib].d) * LM_GGML_CPU_FP16_TO_FP32(y[ib].d) );
557
+
558
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
559
+
560
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
561
+ const __m256i off = _mm256_set1_epi8( 8 );
562
+ qx = _mm256_sub_epi8( qx, off );
563
+
564
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
565
+
566
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
567
+
568
+ /* Multiply q with scale and accumulate */
569
+ acc = _mm256_fmadd_ps( d, q, acc );
570
+ }
571
+
572
+ sumf = hsum_float_8(acc);
573
+ #elif defined(__AVX__)
574
+ __m256 accum = _mm256_setzero_ps();
575
+ for (; ib + 1 < nb; ib += 2) {
576
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
577
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
578
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
579
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
580
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
581
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
582
+
583
+ const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
584
+ const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
585
+ const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
586
+ const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
587
+
588
+ const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
589
+ const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
590
+ const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
591
+ const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
592
+ const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
593
+ const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
594
+ const __m256 p = sum_i16_pairs_float(p_2, p_1);
595
+
596
+ const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
597
+ accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
598
+ }
599
+
600
+ sumf = hsum_float_8(accum);
601
+ #elif defined(__SSSE3__)
602
+ // set constants
603
+ const __m128i lowMask = _mm_set1_epi8(0xF);
604
+ const __m128i off = _mm_set1_epi8(8);
605
+
606
+ // Initialize accumulator with zeros
607
+ __m128 acc_0 = _mm_setzero_ps();
608
+ __m128 acc_1 = _mm_setzero_ps();
609
+ __m128 acc_2 = _mm_setzero_ps();
610
+ __m128 acc_3 = _mm_setzero_ps();
611
+
612
+ for (; ib + 1 < nb; ib += 2) {
613
+ _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);
614
+ _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);
615
+
616
+ // Compute combined scale for the block 0 and 1
617
+ const __m128 d_0_1 = _mm_set1_ps( LM_GGML_CPU_FP16_TO_FP32(x[ib].d) * LM_GGML_CPU_FP16_TO_FP32(y[ib].d) );
618
+
619
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);
620
+
621
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
622
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
623
+ bx_0 = _mm_sub_epi8(bx_0, off);
624
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
625
+
626
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
627
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
628
+ bx_1 = _mm_sub_epi8(bx_1, off);
629
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
630
+
631
+ _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
632
+ _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
633
+
634
+ // Compute combined scale for the block 2 and 3
635
+ const __m128 d_2_3 = _mm_set1_ps( LM_GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );
636
+
637
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
638
+
639
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
640
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
641
+ bx_2 = _mm_sub_epi8(bx_2, off);
642
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
643
+
644
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
645
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));
646
+ bx_3 = _mm_sub_epi8(bx_3, off);
647
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
648
+
649
+ // Convert int32_t to float
650
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
651
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
652
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
653
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
654
+
655
+ // Apply the scale
656
+ __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
657
+ __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
658
+ __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
659
+ __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
660
+
661
+ // Acummulate
662
+ acc_0 = _mm_add_ps(p0_d, acc_0);
663
+ acc_1 = _mm_add_ps(p1_d, acc_1);
664
+ acc_2 = _mm_add_ps(p2_d, acc_2);
665
+ acc_3 = _mm_add_ps(p3_d, acc_3);
666
+ }
667
+
668
+ sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
669
+
670
+ #endif
671
+ for (; ib < nb; ++ib) {
672
+ int sumi0 = 0;
673
+ int sumi1 = 0;
674
+
675
+ for (int j = 0; j < qk/2; ++j) {
676
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
677
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
678
+
679
+ sumi0 += (v0 * y[ib].qs[j]);
680
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
681
+ }
682
+
683
+ int sumi = sumi0 + sumi1;
684
+ sumf += sumi*LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d);
685
+ }
686
+
687
+ *s = sumf;
688
+ }
689
+
690
+ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
691
+ const int qk = QK8_1;
692
+ const int nb = n / qk;
693
+
694
+ assert(n % qk == 0);
695
+ assert(nrc == 1);
696
+ UNUSED(nrc);
697
+ UNUSED(bx);
698
+ UNUSED(by);
699
+ UNUSED(bs);
700
+
701
+ const block_q4_1 * LM_GGML_RESTRICT x = vx;
702
+ const block_q8_1 * LM_GGML_RESTRICT y = vy;
703
+
704
+ int ib = 0;
705
+ float sumf = 0;
706
+
707
+ #if defined(__AVX2__) || defined(__AVX__)
708
+ // Initialize accumulator with zeros
709
+ __m256 acc = _mm256_setzero_ps();
710
+
711
+ float summs = 0;
712
+
713
+ // Main loop
714
+ for (; ib < nb; ++ib) {
715
+ const float d0 = LM_GGML_CPU_FP16_TO_FP32(x[ib].d);
716
+ const float d1 = LM_GGML_CPU_FP16_TO_FP32(y[ib].d);
717
+
718
+ summs += LM_GGML_CPU_FP16_TO_FP32(x[ib].m) * LM_GGML_CPU_FP16_TO_FP32(y[ib].s);
719
+
720
+ const __m256 d0v = _mm256_set1_ps( d0 );
721
+ const __m256 d1v = _mm256_set1_ps( d1 );
722
+
723
+ // Compute combined scales
724
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
725
+
726
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
727
+ const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
728
+ const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs );
729
+
730
+ const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
731
+
732
+ // Accumulate d0*d1*x*y
733
+ #if defined(__AVX2__)
734
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
735
+ #else
736
+ acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
737
+ #endif
738
+ }
739
+
740
+ sumf = hsum_float_8(acc) + summs;
741
+
742
+ #endif
743
+ for (; ib < nb; ++ib) {
744
+ int sumi0 = 0;
745
+ int sumi1 = 0;
746
+
747
+ for (int j = 0; j < qk/2; ++j) {
748
+ const int v0 = (x[ib].qs[j] & 0x0F);
749
+ const int v1 = (x[ib].qs[j] >> 4);
750
+
751
+ sumi0 += (v0 * y[ib].qs[j]);
752
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
753
+ }
754
+
755
+ int sumi = sumi0 + sumi1;
756
+ sumf += (LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_CPU_FP16_TO_FP32(x[ib].m)*LM_GGML_CPU_FP16_TO_FP32(y[ib].s);
757
+ }
758
+
759
+ *s = sumf;
760
+ }
761
+
762
+ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
763
+ const int qk = QK8_0;
764
+ const int nb = n / qk;
765
+
766
+ int ib = 0;
767
+ float sumf = 0;
768
+
769
+ assert(n % qk == 0);
770
+ assert(qk == QK5_0);
771
+ assert(nrc == 1);
772
+ UNUSED(nrc);
773
+ UNUSED(bx);
774
+ UNUSED(by);
775
+ UNUSED(bs);
776
+
777
+ const block_q5_0 * LM_GGML_RESTRICT x = vx;
778
+ const block_q8_0 * LM_GGML_RESTRICT y = vy;
779
+
780
+ #if defined(__AVX2__)
781
+ // Initialize accumulator with zeros
782
+ __m256 acc = _mm256_setzero_ps();
783
+
784
+ // Main loop
785
+ for (; ib < nb; ++ib) {
786
+ /* Compute combined scale for the block */
787
+ const __m256 d = _mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(x[ib].d) * LM_GGML_CPU_FP16_TO_FP32(y[ib].d));
788
+
789
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
790
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
791
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
792
+ qx = _mm256_or_si256(qx, bxhi);
793
+
794
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
795
+
796
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
797
+
798
+ /* Multiply q with scale and accumulate */
799
+ acc = _mm256_fmadd_ps(d, q, acc);
800
+ }
801
+
802
+ sumf = hsum_float_8(acc);
803
+ #elif defined(__AVX__)
804
+ // Initialize accumulator with zeros
805
+ __m256 acc = _mm256_setzero_ps();
806
+ __m128i mask = _mm_set1_epi8((char)0xF0);
807
+
808
+ // Main loop
809
+ for (; ib < nb; ++ib) {
810
+ /* Compute combined scale for the block */
811
+ const __m256 d = _mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(x[ib].d) * LM_GGML_CPU_FP16_TO_FP32(y[ib].d));
812
+
813
+ __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
814
+ const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
815
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
816
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
817
+ bxhil = _mm_andnot_si128(bxhil, mask);
818
+ bxhih = _mm_andnot_si128(bxhih, mask);
819
+ __m128i bxl = _mm256_castsi256_si128(bx_0);
820
+ __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
821
+ bxl = _mm_or_si128(bxl, bxhil);
822
+ bxh = _mm_or_si128(bxh, bxhih);
823
+ bx_0 = MM256_SET_M128I(bxh, bxl);
824
+
825
+ const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
826
+
827
+ const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
828
+
829
+ /* Multiply q with scale and accumulate */
830
+ acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
831
+ }
832
+
833
+ sumf = hsum_float_8(acc);
834
+
835
+ #endif
836
+ for (; ib < nb; ++ib) {
837
+ uint32_t qh;
838
+ memcpy(&qh, x[ib].qh, sizeof(qh));
839
+
840
+ int sumi0 = 0;
841
+ int sumi1 = 0;
842
+
843
+ for (int j = 0; j < qk/2; ++j) {
844
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
845
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
846
+
847
+ const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
848
+ const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
849
+
850
+ sumi0 += (x0 * y[ib].qs[j]);
851
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
852
+ }
853
+
854
+ int sumi = sumi0 + sumi1;
855
+ sumf += (LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
856
+ }
857
+
858
+ *s = sumf;
859
+ }
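The q5_0 path splices a fifth bit from the 32-bit `qh` field onto each 4-bit value before applying the signed -16 offset. A scalar sketch of that reconstruction for one byte of `qs`, equivalent to the tail loop above (helper name illustrative):

```c
#include <stdint.h>

// Rebuild the two signed 5-bit values held at position j of a q5_0 block:
// the low/high nibble of qs_j supplies bits 0..3, bit j (resp. j+16) of qh
// supplies bit 4, and subtracting 16 recenters the result to [-16, 15].
static void q5_0_unpack_pair(uint32_t qh, uint8_t qs_j, int j, int *v0, int *v1) {
    const uint8_t xh_0 = (uint8_t)(((qh >> (j +  0)) & 1) << 4);
    const uint8_t xh_1 = (uint8_t)(((qh >> (j + 16)) & 1) << 4);
    *v0 = (int8_t)(((qs_j & 0x0F) | xh_0) - 16);
    *v1 = (int8_t)(((qs_j >>   4) | xh_1) - 16);
}
```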
860
+
861
+ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
862
+ const int qk = QK8_1;
863
+ const int nb = n / qk;
864
+
865
+ int ib = 0;
866
+ float sumf = 0;
867
+
868
+ assert(n % qk == 0);
869
+ assert(qk == QK5_1);
870
+ assert(nrc == 1);
871
+ UNUSED(nrc);
872
+ UNUSED(bx);
873
+ UNUSED(by);
874
+ UNUSED(bs);
875
+
876
+ const block_q5_1 * LM_GGML_RESTRICT x = vx;
877
+ const block_q8_1 * LM_GGML_RESTRICT y = vy;
878
+
879
+ #if defined(__AVX2__)
880
+ // Initialize accumulator with zeros
881
+ __m256 acc = _mm256_setzero_ps();
882
+
883
+ float summs = 0.0f;
884
+
885
+ // Main loop
886
+ for (; ib < nb; ++ib) {
887
+ const __m256 dx = _mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(x[ib].d));
888
+
889
+ summs += LM_GGML_CPU_FP16_TO_FP32(x[ib].m) * LM_GGML_CPU_FP16_TO_FP32(y[ib].s);
890
+
891
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
892
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
893
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
894
+ qx = _mm256_or_si256(qx, bxhi);
895
+
896
+ const __m256 dy = _mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(y[ib].d));
897
+ const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
898
+
899
+ const __m256 q = mul_sum_us8_pairs_float(qx, qy);
900
+
901
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
902
+ }
903
+
904
+ sumf = hsum_float_8(acc) + summs;
905
+ #elif defined(__AVX__)
906
+ // Initialize accumulator with zeros
907
+ __m256 acc = _mm256_setzero_ps();
908
+ __m128i mask = _mm_set1_epi8(0x10);
909
+
910
+ float summs = 0.0f;
911
+
912
+ // Main loop
913
+ for (; ib < nb; ++ib) {
914
+ const __m256 dx = _mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(x[ib].d));
915
+
916
+ summs += LM_GGML_CPU_FP16_TO_FP32(x[ib].m) * LM_GGML_CPU_FP16_TO_FP32(y[ib].s);
917
+
918
+ __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
919
+ const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
920
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
921
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
922
+ bxhil = _mm_and_si128(bxhil, mask);
923
+ bxhih = _mm_and_si128(bxhih, mask);
924
+ __m128i bxl = _mm256_castsi256_si128(bx_0);
925
+ __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
926
+ bxl = _mm_or_si128(bxl, bxhil);
927
+ bxh = _mm_or_si128(bxh, bxhih);
928
+ bx_0 = MM256_SET_M128I(bxh, bxl);
929
+
930
+ const __m256 dy = _mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(y[ib].d));
931
+ const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
932
+
933
+ const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
934
+
935
+ acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
936
+ }
937
+
938
+ sumf = hsum_float_8(acc) + summs;
939
+
940
+ #endif
941
+ for (; ib < nb; ++ib) {
942
+ uint32_t qh;
943
+ memcpy(&qh, x[ib].qh, sizeof(qh));
944
+
945
+ int sumi0 = 0;
946
+ int sumi1 = 0;
947
+
948
+ for (int j = 0; j < qk/2; ++j) {
949
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
950
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
951
+
952
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
953
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
954
+
955
+ sumi0 += (x0 * y[ib].qs[j]);
956
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
957
+ }
958
+
959
+ int sumi = sumi0 + sumi1;
960
+ sumf += (LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_CPU_FP16_TO_FP32(x[ib].m)*LM_GGML_CPU_FP16_TO_FP32(y[ib].s);
961
+ }
962
+
963
+ *s = sumf;
964
+ }
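q5_1 reconstructs its 5-bit values the same way but without the -16 offset; instead each block carries a scale `d` and a minimum `m`, and the q8_1 block carries `s = d_y * sum(q_y)`, so the per-block contribution splits as sketched below (names are illustrative, not from these sources):

```c
#include <stdint.h>

// One block's contribution to a q5_1 x q8_1 dot product in plain scalar form:
// dequant(x) = dx*v + mx and dequant(y) = dy*q8, so the sum over the block
// factors into an integer dot product plus a precomputed min correction.
static float q5_1_block_dot(float dx, float mx, float dy, float sy,
                            const int v[32], const int8_t q8[32]) {
    int sumi = 0;
    for (int j = 0; j < 32; ++j) {
        sumi += v[j] * q8[j];
    }
    return dx * dy * (float)sumi + mx * sy; // sy == dy * sum(q8), as stored in block_q8_1.s
}
```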
965
+
966
+ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
967
+ const int qk = QK8_0;
968
+ const int nb = n / qk;
969
+
970
+ assert(n % qk == 0);
971
+ assert(nrc == 1);
972
+ UNUSED(nrc);
973
+ UNUSED(bx);
974
+ UNUSED(by);
975
+ UNUSED(bs);
976
+
977
+ const block_q8_0 * LM_GGML_RESTRICT x = vx;
978
+ const block_q8_0 * LM_GGML_RESTRICT y = vy;
979
+
980
+ int ib = 0;
981
+ float sumf = 0;
982
+
983
+ #if defined(__AVX2__)
984
+ // Initialize accumulator with zeros
985
+ __m256 acc = _mm256_setzero_ps();
986
+
987
+ // Main loop
988
+ for (; ib < nb; ++ib) {
989
+ // Compute combined scale for the block
990
+ const __m256 d = _mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(x[ib].d) * LM_GGML_CPU_FP16_TO_FP32(y[ib].d));
991
+ __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs);
992
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
993
+
994
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
995
+
996
+ // Multiply q with scale and accumulate
997
+ acc = _mm256_fmadd_ps( d, q, acc );
998
+ }
999
+
1000
+ sumf = hsum_float_8(acc);
1001
+ #elif defined(__AVX__)
1002
+ __m256 accum = _mm256_setzero_ps();
1003
+
1004
+ for (; ib + 1 < nb; ib += 2) {
1005
+ const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs);
1006
+ const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1);
1007
+ const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
1008
+ const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1);
1009
+ const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
1010
+ const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1);
1011
+ const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
1012
+ const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
1013
+
1014
+ const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1);
1015
+ const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
1016
+ accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
1017
+ }
1018
+
1019
+ sumf = hsum_float_8(accum);
1020
+
1021
+ #endif
1022
+ for (; ib < nb; ++ib) {
1023
+ int sumi = 0;
1024
+
1025
+ for (int j = 0; j < qk; j++) {
1026
+ sumi += x[ib].qs[j]*y[ib].qs[j];
1027
+ }
1028
+
1029
+ sumf += sumi*(LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d));
1030
+ }
1031
+
1032
+ *s = sumf;
1033
+ }
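Every x86 path above finishes by collapsing a `__m256` accumulator into a single float via `hsum_float_8`, which is defined alongside these kernels. A rough sketch of the kind of reduction it performs (an assumption about the exact instruction sequence, shown only to make the data flow explicit):

```c
#include <immintrin.h>

// Horizontal sum of the 8 floats in a __m256 (sketch in the spirit of
// hsum_float_8; the real helper lives elsewhere in these sources).
static inline float hsum8_sketch(__m256 v) {
    __m128 lo = _mm256_castps256_ps128(v);      // lanes 0..3
    __m128 hi = _mm256_extractf128_ps(v, 1);    // lanes 4..7
    lo = _mm_add_ps(lo, hi);                    // 4 partial sums
    lo = _mm_add_ps(lo, _mm_movehl_ps(lo, lo)); // 2 partial sums
    lo = _mm_add_ss(lo, _mm_movehdup_ps(lo));   // total in lane 0
    return _mm_cvtss_f32(lo);
}
```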
1034
+
1035
+ void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
1036
+ assert(nrc == 1);
1037
+ UNUSED(nrc);
1038
+ UNUSED(bx);
1039
+ UNUSED(by);
1040
+ UNUSED(bs);
1041
+
1042
+ const block_tq1_0 * LM_GGML_RESTRICT x = vx;
1043
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
1044
+
1045
+ const int nb = n / QK_K;
1046
+
1047
+ #if defined(__AVX2__)
1048
+ __m256 sumf = _mm256_setzero_ps();
1049
+
1050
+ for (int i = 0; i < nb; ++i) {
1051
+ // 16-bit sums
1052
+ __m256i sumi0 = _mm256_setzero_si256();
1053
+ __m256i sumi1 = _mm256_setzero_si256();
1054
+ __m256i sumi2 = _mm256_setzero_si256();
1055
+
1056
+ // first 32 bytes, each packing 5 elements
1057
+ {
1058
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
1059
+ // 8-bit multiplies with shifts, masks and adds
1060
+ __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
1061
+ __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
1062
+ __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
1063
+ __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
1064
+
1065
+ // TODO: could _mm256_mulhi_epu16 be faster even though it operates on 16-bit lanes?
1066
+
1067
+ // Cancel the +1 from avg so that it behaves like a halving add
1068
+ qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
1069
+ qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
1070
+ qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
1071
+ qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
1072
+ qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
1073
+ // Multiply by 3 and get the top 2 bits
1074
+ qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
1075
+ qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
1076
+ qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
1077
+ qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
1078
+ qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
1079
+ qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
1080
+ qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
1081
+ qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
1082
+ qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
1083
+ qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
1084
+
1085
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
1086
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
1087
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
1088
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
1089
+ const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
1090
+
1091
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
1092
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
1093
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
1094
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
1095
+ qx4 = _mm256_maddubs_epi16(qx4, qy4);
1096
+
1097
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
1098
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
1099
+ sumi2 = _mm256_add_epi16(sumi2, qx4);
1100
+ }
1101
+
1102
+ // last 16 bytes, each packing 5 elements, along with the 4 qh bytes packing 4 elements each
1103
+ {
1104
+ __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
1105
+ uint32_t qh;
1106
+ memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
1107
+ __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
1108
+ __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
1109
+ __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
1110
+ __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
1111
+ __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
1112
+ __m256i qx01 = MM256_SET_M128I(qx1, qx0);
1113
+ __m256i qx23 = MM256_SET_M128I(qx3, qx2);
1114
+
1115
+ // avx2 does not have 8-bit multiplies, so 16-bit it is.
1116
+ qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
1117
+ qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
1118
+ __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
1119
+
1120
+ __m256i qx45 = MM256_SET_M128I(qx5, qx4);
1121
+
1122
+ // Cancel the +1 from avg so that it behaves like a halving add
1123
+ qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
1124
+ qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
1125
+ qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
1126
+ // Multiply by 3 and get the top 2 bits
1127
+ qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
1128
+ qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
1129
+ qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
1130
+ qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
1131
+ qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
1132
+ qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
1133
+
1134
+ const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
1135
+ const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
1136
+ const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
1137
+
1138
+ qx01 = _mm256_maddubs_epi16(qx01, qy01);
1139
+ qx23 = _mm256_maddubs_epi16(qx23, qy23);
1140
+ qx45 = _mm256_maddubs_epi16(qx45, qy45);
1141
+
1142
+ sumi0 = _mm256_add_epi16(sumi0, qx01);
1143
+ sumi1 = _mm256_add_epi16(sumi1, qx23);
1144
+ sumi2 = _mm256_add_epi16(sumi2, qx45);
1145
+ }
1146
+
1147
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
1148
+ const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d));
1149
+
1150
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
1151
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
1152
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
1153
+
1154
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
1155
+ }
1156
+
1157
+ *s = hsum_float_8(sumf);
1158
+
1159
+ #else
1160
+ const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
1161
+
1162
+ float sumf = 0.0f;
1163
+
1164
+ for (int i = 0; i < nb; ++i) {
1165
+ int sum = 0;
1166
+
1167
+ for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
1168
+ for (size_t l = 0; l < 5; ++l) {
1169
+ for (size_t m = 0; m < 32; ++m) {
1170
+ uint8_t q = x[i].qs[j + m] * pow3[l];
1171
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
1172
+ sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
1173
+ }
1174
+ }
1175
+ }
1176
+ for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
1177
+ for (size_t l = 0; l < 5; ++l) {
1178
+ for (size_t m = 0; m < 16; ++m) {
1179
+ uint8_t q = x[i].qs[j + m] * pow3[l];
1180
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
1181
+ sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
1182
+ }
1183
+ }
1184
+ }
1185
+
1186
+ for (size_t l = 0; l < 4; ++l) {
1187
+ for (size_t j = 0; j < sizeof(x->qh); ++j) {
1188
+ uint8_t q = x[i].qh[j] * pow3[l];
1189
+ uint16_t xi = ((uint16_t) q * 3) >> 8;
1190
+ sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
1191
+ }
1192
+ }
1193
+
1194
+ sumf += (float) sum * (LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d);
1195
+ }
1196
+
1197
+ *s = sumf;
1198
+ #endif
1199
+ }
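Both the AVX2 block and the scalar fallback above decode tq1_0's ternary packing the same way: five ternary digits per byte, packed so that multiplying by 3^l brings digit l into the top bits, where `((q * 3) >> 8)` reads it out as 0, 1 or 2. A scalar sketch of a single-digit extraction (illustrative helper, mirroring the `pow3[]` loop):

```c
#include <stdint.h>

// Extract ternary digit l (0..4) from one tq1_0 byte and map it to {-1, 0, +1}.
static int tq1_0_digit(uint8_t packed, int l) {
    static const uint8_t pow3[5] = {1, 3, 9, 27, 81};
    const uint8_t  q  = (uint8_t)(packed * pow3[l]);   // bring digit l to the top
    const uint16_t xi = ((uint16_t) q * 3) >> 8;       // 0, 1 or 2
    return (int) xi - 1;                               // -1, 0 or +1
}
```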
1200
+
1201
+ void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
1202
+ assert(nrc == 1);
1203
+ UNUSED(nrc);
1204
+ UNUSED(bx);
1205
+ UNUSED(by);
1206
+ UNUSED(bs);
1207
+
1208
+ const block_tq2_0 * LM_GGML_RESTRICT x = vx;
1209
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
1210
+
1211
+ const int nb = n / QK_K;
1212
+
1213
+ #if defined(__AVX2__)
1214
+ __m256 sumf = _mm256_setzero_ps();
1215
+
1216
+ for (int i = 0; i < nb; ++i) {
1217
+ // 16-bit sums, because 256*127 still fits
1218
+ __m256i sumi0 = _mm256_setzero_si256();
1219
+ __m256i sumi1 = _mm256_setzero_si256();
1220
+
1221
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
1222
+ __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
1223
+ __m256i qx1 = _mm256_srli_epi16(qx0, 2);
1224
+ __m256i qx2 = _mm256_srli_epi16(qx0, 4);
1225
+ __m256i qx3 = _mm256_srli_epi16(qx0, 6);
1226
+
1227
+ // 0, 1, 2 (should not be 3)
1228
+ qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
1229
+ qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
1230
+ qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
1231
+ qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
1232
+
1233
+ const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0));
1234
+ const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
1235
+ const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
1236
+ const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
1237
+
1238
+ qx0 = _mm256_maddubs_epi16(qx0, qy0);
1239
+ qx1 = _mm256_maddubs_epi16(qx1, qy1);
1240
+ qx2 = _mm256_maddubs_epi16(qx2, qy2);
1241
+ qx3 = _mm256_maddubs_epi16(qx3, qy3);
1242
+
1243
+ sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
1244
+ sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
1245
+ }
1246
+
1247
+ const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
1248
+ const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d));
1249
+
1250
+ sumi0 = _mm256_add_epi16(sumi0, sumi1);
1251
+ sumi0 = _mm256_sub_epi16(sumi0, ysum);
1252
+ sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
1253
+
1254
+ sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
1255
+ }
1256
+
1257
+ *s = hsum_float_8(sumf);
1258
+
1259
+ #else
1260
+ float sumf = 0.0f;
1261
+
1262
+ for (int i = 0; i < nb; ++i) {
1263
+ int32_t sumi = 0;
1264
+
1265
+ for (size_t j = 0; j < sizeof(x->qs); j += 32) {
1266
+ for (size_t l = 0; l < 4; ++l) {
1267
+ for (size_t k = 0; k < 32; ++k) {
1268
+ sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
1269
+ }
1270
+ }
1271
+ }
1272
+
1273
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1274
+
1275
+ sumf += (float) sumi * d;
1276
+ }
1277
+
1278
+ *s = sumf;
1279
+ #endif
1280
+ }
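tq2_0 is simpler: four 2-bit fields per byte, each holding 0, 1 or 2 (never 3), shifted back to {-1, 0, +1} by subtracting 1, which is exactly what the fallback loop does. A one-line scalar sketch (helper name illustrative):

```c
#include <stdint.h>

// Extract field l (0..3) from one tq2_0 byte as a ternary weight in {-1, 0, +1}.
static int tq2_0_digit(uint8_t packed, int l) {
    return ((packed >> (l * 2)) & 3) - 1;
}
```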
1281
+
1282
+ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
1283
+ assert(nrc == 1);
1284
+ UNUSED(nrc);
1285
+ UNUSED(bx);
1286
+ UNUSED(by);
1287
+ UNUSED(bs);
1288
+
1289
+ const block_q2_K * LM_GGML_RESTRICT x = vx;
1290
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
1291
+
1292
+ const int nb = n / QK_K;
1293
+
1294
+ #if defined __AVX2__
1295
+
1296
+ const __m256i m3 = _mm256_set1_epi8(3);
1297
+ const __m128i m4 = _mm_set1_epi8(0xF);
1298
+
1299
+ __m256 acc = _mm256_setzero_ps();
1300
+
1301
+ for (int i = 0; i < nb; ++i) {
1302
+
1303
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1304
+ const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
1305
+
1306
+ const uint8_t * LM_GGML_RESTRICT q2 = x[i].qs;
1307
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
1308
+
1309
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1310
+ const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
1311
+ const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1312
+ const __m256i mins = _mm256_cvtepi8_epi16(mins8);
1313
+ const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
1314
+
1315
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
1316
+
1317
+ const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
1318
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1319
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1320
+ const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
1321
+
1322
+ __m256i sumi = _mm256_setzero_si256();
1323
+
1324
+ for (int j = 0; j < QK_K/128; ++j) {
1325
+
1326
+ const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
1327
+
1328
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1329
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1330
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1331
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1332
+
1333
+ const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
1334
+ const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
1335
+ const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
1336
+ const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
1337
+
1338
+ __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
1339
+ __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
1340
+ __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
1341
+ __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
1342
+
1343
+ p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
1344
+ p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
1345
+ p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
1346
+ p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
1347
+
1348
+ p0 = _mm256_add_epi32(p0, p1);
1349
+ p2 = _mm256_add_epi32(p2, p3);
1350
+
1351
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
1352
+ }
1353
+
1354
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
1355
+
1356
+ }
1357
+
1358
+ *s = hsum_float_8(acc);
1359
+
1360
+ #elif defined __AVX__
1361
+
1362
+ const __m128i m3 = _mm_set1_epi8(0x3);
1363
+ const __m128i m4 = _mm_set1_epi8(0xF);
1364
+ const __m128i m2 = _mm_set1_epi8(0x2);
1365
+
1366
+ __m256 acc = _mm256_setzero_ps();
1367
+
1368
+ for (int i = 0; i < nb; ++i) {
1369
+
1370
+ const float dall = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1371
+ const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
1372
+
1373
+ const uint8_t * LM_GGML_RESTRICT q2 = x[i].qs;
1374
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
1375
+
1376
+ // load mins and scales from block_q2_K.scales[QK_K/16]
1377
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1378
+ const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
1379
+ const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1380
+ const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
1381
+ const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
1382
+
1383
+ // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
1384
+ const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
1385
+ const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
1386
+
1387
+ // sumf += -dmin * summs in 32bits*8
1388
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
1389
+
1390
+ const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
1391
+ const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
1392
+ const __m128i scales[2] = { scales_0, scales_1 };
1393
+
1394
+ __m128i sumi_0 = _mm_setzero_si128();
1395
+ __m128i sumi_1 = _mm_setzero_si128();
1396
+
1397
+ for (int j = 0; j < QK_K/128; ++j) {
1398
+
1399
+ // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
1400
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1401
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1402
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1403
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1404
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1405
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1406
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1407
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1408
+
1409
+ // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
1410
+ __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1411
+ const __m128i q2_0 = _mm_and_si128(q2bits, m3);
1412
+ const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1413
+ const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1414
+ const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1415
+ q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1416
+ const __m128i q2_1 = _mm_and_si128(q2bits, m3);
1417
+ const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1418
+ const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1419
+ const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1420
+
1421
+ // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
1422
+ __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
1423
+ __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
1424
+ __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
1425
+ __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
1426
+ __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
1427
+ __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
1428
+ __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
1429
+ __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
1430
+
1431
+ // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
1432
+ __m128i shuffle = _mm_set1_epi16(0x0100);
1433
+ p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
1434
+ shuffle = _mm_add_epi16(shuffle, m2);
1435
+ p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
1436
+ shuffle = _mm_add_epi16(shuffle, m2);
1437
+ p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
1438
+ shuffle = _mm_add_epi16(shuffle, m2);
1439
+ p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
1440
+ shuffle = _mm_add_epi16(shuffle, m2);
1441
+ p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
1442
+ shuffle = _mm_add_epi16(shuffle, m2);
1443
+ p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
1444
+ shuffle = _mm_add_epi16(shuffle, m2);
1445
+ p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
1446
+ shuffle = _mm_add_epi16(shuffle, m2);
1447
+ p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
1448
+
1449
+ p0 = _mm_add_epi32(p0, p1);
1450
+ p2 = _mm_add_epi32(p2, p3);
1451
+ p4 = _mm_add_epi32(p4, p5);
1452
+ p6 = _mm_add_epi32(p6, p7);
1453
+
1454
+ // isum in 32bits*4*2
1455
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
1456
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
1457
+ }
1458
+
1459
+ // sumf += dall * isum - dmin * summs in 32bits
1460
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1461
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
1462
+ }
1463
+
1464
+ *s = hsum_float_8(acc);
1465
+
1466
+ #else
1467
+
1468
+ float sumf = 0;
1469
+
1470
+ for (int i = 0; i < nb; ++i) {
1471
+
1472
+ const uint8_t * q2 = x[i].qs;
1473
+ const int8_t * q8 = y[i].qs;
1474
+ const uint8_t * sc = x[i].scales;
1475
+
1476
+ int summs = 0;
1477
+ for (int j = 0; j < 16; ++j) {
1478
+ summs += y[i].bsums[j] * (sc[j] >> 4);
1479
+ }
1480
+
1481
+ const float dall = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1482
+ const float dmin = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
1483
+
1484
+ int isum = 0;
1485
+ int is = 0;
1486
+ int d;
1487
+ for (int k = 0; k < QK_K/128; ++k) {
1488
+ int shift = 0;
1489
+ for (int j = 0; j < 4; ++j) {
1490
+ d = sc[is++] & 0xF;
1491
+ int isuml = 0;
1492
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
1493
+ isum += d * isuml;
1494
+ d = sc[is++] & 0xF;
1495
+ isuml = 0;
1496
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
1497
+ isum += d * isuml;
1498
+ shift += 2;
1499
+ q8 += 32;
1500
+ }
1501
+ q2 += 32;
1502
+ }
1503
+ sumf += dall * isum - dmin * summs;
1504
+ }
1505
+ *s = sumf;
1506
+ #endif
1507
+ }
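In q2_K each of the 16 `scales` bytes carries a 4-bit sub-block scale in its low nibble and a 4-bit sub-block min in its high nibble; the SIMD paths and the scalar fallback all compute the same per-block combination, sketched here with illustrative inputs (the per-sub-block integer dot products and the q8 `bsums`):

```c
#include <stdint.h>

// Combine per-sub-block results for one q2_K block, where
// dall = y.d * fp16_to_fp32(x.d) and dmin = y.d * fp16_to_fp32(x.dmin).
static float q2_K_block_combine(float dall, float dmin, const uint8_t sc[16],
                                const int isum[16], const int16_t bsums[16]) {
    int dot = 0, mins = 0;
    for (int k = 0; k < 16; ++k) {
        dot  += (sc[k] & 0xF) * isum[k];   // low nibble: sub-block scale
        mins += (sc[k] >> 4)  * bsums[k];  // high nibble: sub-block min
    }
    return dall * (float)dot - dmin * (float)mins;
}
```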
1508
+
1509
+ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
1510
+ assert(n % QK_K == 0);
1511
+ assert(nrc == 1);
1512
+ UNUSED(nrc);
1513
+ UNUSED(bx);
1514
+ UNUSED(by);
1515
+ UNUSED(bs);
1516
+
1517
+ const uint32_t kmask1 = 0x03030303;
1518
+ const uint32_t kmask2 = 0x0f0f0f0f;
1519
+
1520
+ const block_q3_K * LM_GGML_RESTRICT x = vx;
1521
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
1522
+
1523
+ const int nb = n / QK_K;
1524
+
1525
+ #if defined __AVX2__
1526
+
1527
+ const __m256i m3 = _mm256_set1_epi8(3);
1528
+ const __m256i mone = _mm256_set1_epi8(1);
1529
+ const __m128i m32 = _mm_set1_epi8(32);
1530
+
1531
+ __m256 acc = _mm256_setzero_ps();
1532
+
1533
+ uint32_t aux[3];
1534
+
1535
+ for (int i = 0; i < nb; ++i) {
1536
+
1537
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1538
+
1539
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
1540
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
1541
+
1542
+ // Set up scales
1543
+ memcpy(aux, x[i].scales, 12);
1544
+ __m128i scales128 = _mm_set_epi32(
1545
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1546
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1547
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1548
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1549
+ scales128 = _mm_sub_epi8(scales128, m32);
1550
+ const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
1551
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1552
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1553
+ const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
1554
+
1555
+ // high bit
1556
+ const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
1557
+
1558
+ // integer accumulator
1559
+ __m256i sumi = _mm256_setzero_si256();
1560
+
1561
+ int bit = 0;
1562
+ int is = 0;
1563
+
1564
+ for (int j = 0; j < QK_K/128; ++j) {
1565
+ // load low 2 bits
1566
+ const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
1567
+
1568
+ // prepare low and high bits
1569
+ const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
1570
+ const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1571
+ ++bit;
1572
+
1573
+ const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
1574
+ const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1575
+ ++bit;
1576
+
1577
+ const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
1578
+ const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1579
+ ++bit;
1580
+
1581
+ const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
1582
+ const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1583
+ ++bit;
1584
+
1585
+ // load Q8 quants
1586
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1587
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1588
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1589
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1590
+
1591
+ // Dot product: we multiply the 2 low bits and the high-bit correction separately, so we can use _mm256_maddubs_epi16
1592
+ // and then subtract. The correction term already carries the offset: it is 4 where the high bit is clear
1593
+ // (and 0 where it is set), matching the -4 applied in the scalar reference below.
1594
+ __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
1595
+ __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
1596
+ __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
1597
+ __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
1598
+
1599
+ __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
1600
+ __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
1601
+ __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
1602
+ __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
1603
+
1604
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
1605
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
1606
+ p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
1607
+ p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
1608
+
1609
+ // multiply with scales
1610
+ p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
1611
+ p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
1612
+ p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
1613
+ p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
1614
+
1615
+ // accumulate
1616
+ p16_0 = _mm256_add_epi32(p16_0, p16_1);
1617
+ p16_2 = _mm256_add_epi32(p16_2, p16_3);
1618
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
1619
+
1620
+ }
1621
+
1622
+ // multiply with block scale and accumulate
1623
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
1624
+
1625
+ }
1626
+
1627
+ *s = hsum_float_8(acc);
1628
+
1629
+ #elif defined __AVX__
1630
+
1631
+ const __m128i m3 = _mm_set1_epi8(3);
1632
+ const __m128i mone = _mm_set1_epi8(1);
1633
+ const __m128i m32 = _mm_set1_epi8(32);
1634
+ const __m128i m2 = _mm_set1_epi8(2);
1635
+
1636
+ __m256 acc = _mm256_setzero_ps();
1637
+
1638
+ const uint32_t *aux;
1639
+
1640
+ for (int i = 0; i < nb; ++i) {
1641
+
1642
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1643
+
1644
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
1645
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
1646
+
1647
+ // Set up scales
1648
+ aux = (const uint32_t *)x[i].scales;
1649
+ __m128i scales128 = _mm_set_epi32(
1650
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1651
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1652
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1653
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1654
+ scales128 = _mm_sub_epi8(scales128, m32);
1655
+ const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
1656
+ const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
1657
+ const __m128i scales[2] = { scales_0, scales_1 };
1658
+
1659
+ // high bit *128*2 from block_q3_K.hmask[QK_K/8]
1660
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
1661
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
1662
+
1663
+ // integer accumulator
1664
+ __m128i sumi_0 = _mm_setzero_si128();
1665
+ __m128i sumi_1 = _mm_setzero_si128();
1666
+
1667
+ for (int j = 0; j < QK_K/128; ++j) {
1668
+ // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
1669
+ const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1670
+ const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1671
+
1672
+ // prepare low and high bits
1673
+ const int bit = j << 2;
1674
+
1675
+ const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
1676
+ const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
1677
+ const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
1678
+ const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
1679
+
1680
+ const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
1681
+ const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
1682
+ const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1683
+ const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1684
+
1685
+ const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
1686
+ const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
1687
+ const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1688
+ const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1689
+
1690
+ const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
1691
+ const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
1692
+ const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
1693
+ const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
1694
+
1695
+ // load Q8 quants from block_q8_K.qs[QK_K]
1696
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1697
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1698
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1699
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1700
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1701
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1702
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1703
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1704
+
1705
+ // Dot product: we multiply the 2 low bits and the high-bit correction separately, so we can use _mm_maddubs_epi16
1706
+ // and then subtract. The correction term already carries the offset: it is 4 where the high bit is clear
1707
+ // (and 0 where it is set), matching the -4 applied in the scalar reference below.
1708
+ __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
1709
+ __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
1710
+ __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
1711
+ __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
1712
+ __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
1713
+ __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
1714
+ __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
1715
+ __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
1716
+
1717
+ __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
1718
+ __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
1719
+ __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
1720
+ __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
1721
+ __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
1722
+ __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
1723
+ __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
1724
+ __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
1725
+
1726
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
1727
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
1728
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
1729
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
1730
+ p16_4 = _mm_sub_epi16(p16_4, q8s_4);
1731
+ p16_5 = _mm_sub_epi16(p16_5, q8s_5);
1732
+ p16_6 = _mm_sub_epi16(p16_6, q8s_6);
1733
+ p16_7 = _mm_sub_epi16(p16_7, q8s_7);
1734
+
1735
+ // multiply with scales
1736
+ __m128i shuffle = _mm_set1_epi16(0x0100);
1737
+ p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
1738
+ shuffle = _mm_add_epi16(shuffle, m2);
1739
+ p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
1740
+ shuffle = _mm_add_epi16(shuffle, m2);
1741
+ p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
1742
+ shuffle = _mm_add_epi16(shuffle, m2);
1743
+ p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
1744
+ shuffle = _mm_add_epi16(shuffle, m2);
1745
+ p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
1746
+ shuffle = _mm_add_epi16(shuffle, m2);
1747
+ p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
1748
+ shuffle = _mm_add_epi16(shuffle, m2);
1749
+ p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
1750
+ shuffle = _mm_add_epi16(shuffle, m2);
1751
+ p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
1752
+
1753
+ // accumulate
1754
+ p16_0 = _mm_add_epi32(p16_0, p16_1);
1755
+ p16_2 = _mm_add_epi32(p16_2, p16_3);
1756
+ p16_4 = _mm_add_epi32(p16_4, p16_5);
1757
+ p16_6 = _mm_add_epi32(p16_6, p16_7);
1758
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
1759
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
1760
+
1761
+ }
1762
+
1763
+ // multiply with block scale and accumulate
1764
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1765
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
1766
+
1767
+ }
1768
+
1769
+ *s = hsum_float_8(acc);
1770
+
1771
+ #else
1772
+ // scalar version
1773
+ // This function is written this way so that the compiler can vectorize most of it.
1774
+ // With -Ofast, GCC and clang produce code that is within a factor of 2 or so of the
1775
+ // manually vectorized version above. Every other version I tried ran at least 4 times slower.
1776
+ // The ideal situation would be if we could just write the code once, and the compiler would
1777
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
1778
+ // write vectorized versions for AVX, ARM_NEON, etc.
1779
+
1780
+ int8_t aux8[QK_K];
1781
+ int16_t aux16[8];
1782
+ float sums [8];
1783
+ int32_t aux32[8];
1784
+ memset(sums, 0, 8*sizeof(float));
1785
+
1786
+ uint32_t auxs[4];
1787
+ const int8_t * scales = (const int8_t*)auxs;
1788
+
1789
+ float sumf = 0;
1790
+ for (int i = 0; i < nb; ++i) {
1791
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
1792
+ const uint8_t * LM_GGML_RESTRICT hm = x[i].hmask;
1793
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
1794
+ memset(aux32, 0, 8*sizeof(int32_t));
1795
+ int8_t * LM_GGML_RESTRICT a = aux8;
1796
+ uint8_t m = 1;
1797
+ for (int j = 0; j < QK_K; j += 128) {
1798
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
1799
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1800
+ a += 32; m <<= 1;
1801
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
1802
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1803
+ a += 32; m <<= 1;
1804
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
1805
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1806
+ a += 32; m <<= 1;
1807
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
1808
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1809
+ a += 32; m <<= 1;
1810
+ q3 += 32;
1811
+ }
1812
+ a = aux8;
1813
+
1814
+ memcpy(auxs, x[i].scales, 12);
1815
+ uint32_t tmp = auxs[2];
1816
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
1817
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
1818
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
1819
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
1820
+ for (int j = 0; j < QK_K/16; ++j) {
1821
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1822
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1823
+ q8 += 8; a += 8;
1824
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1825
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1826
+ q8 += 8; a += 8;
1827
+ }
1828
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1829
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1830
+ }
1831
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1832
+ *s = sumf;
1833
+
1834
+ #endif
1835
+
1836
+ }
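q3_K stores each weight as two low bits in `qs` plus one high bit in `hmask`; a clear high bit means the value was offset by -4 at quantization time, which is why the SIMD paths subtract a maddubs of the `q3h` correction term. The scalar reconstruction of a single weight (names illustrative, mirroring the fallback above):

```c
#include <stdint.h>

// Rebuild one q3_K weight in [-4, 3]: low 2 bits from the qs byte at the given
// shift, minus 4 when the corresponding hmask bit (selected by m) is clear.
static int q3_K_weight(uint8_t qs_byte, uint8_t hmask_byte, int shift, uint8_t m) {
    int v = (qs_byte >> shift) & 3;
    if ((hmask_byte & m) == 0) {
        v -= 4;
    }
    return v;
}
```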
1837
+
1838
+ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
1839
+ assert(n % QK_K == 0);
1840
+ assert(nrc == 1);
1841
+ UNUSED(nrc);
1842
+ UNUSED(bx);
1843
+ UNUSED(by);
1844
+ UNUSED(bs);
1845
+
1846
+ const block_q4_K * LM_GGML_RESTRICT x = vx;
1847
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
1848
+
1849
+ const int nb = n / QK_K;
1850
+
1851
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1852
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1853
+ static const uint32_t kmask3 = 0x03030303;
1854
+
1855
+ uint32_t utmp[4];
1856
+
1857
+ #if defined __AVX2__
1858
+
1859
+ const __m256i m4 = _mm256_set1_epi8(0xF);
1860
+
1861
+ __m256 acc = _mm256_setzero_ps();
1862
+ __m128 acc_m = _mm_setzero_ps();
1863
+
1864
+ for (int i = 0; i < nb; ++i) {
1865
+
1866
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1867
+ const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
1868
+
1869
+ memcpy(utmp, x[i].scales, 12);
1870
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1871
+ const uint32_t uaux = utmp[1] & kmask1;
1872
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1873
+ utmp[2] = uaux;
1874
+ utmp[0] &= kmask1;
1875
+
1876
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
1877
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
1878
+
1879
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1880
+
1881
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
1882
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1883
+ const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
1884
+ acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
1885
+
1886
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
1887
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
1888
+
1889
+ __m256i sumi = _mm256_setzero_si256();
1890
+
1891
+ for (int j = 0; j < QK_K/64; ++j) {
1892
+
1893
+ const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
1894
+ const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
1895
+
1896
+ const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
1897
+ const __m256i q4l = _mm256_and_si256(q4bits, m4);
1898
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
1899
+
1900
+ const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1901
+ __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
1902
+ p16l = _mm256_madd_epi16(scale_l, p16l);
1903
+
1904
+ const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1905
+ __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
1906
+ p16h = _mm256_madd_epi16(scale_h, p16h);
1907
+ const __m256i sumj = _mm256_add_epi32(p16l, p16h);
1908
+
1909
+ sumi = _mm256_add_epi32(sumi, sumj);
1910
+ }
1911
+
1912
+ __m256 vd = _mm256_set1_ps(d);
1913
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
1914
+
1915
+ }
1916
+
1917
+ acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
1918
+ acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
1919
+
1920
+ *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1921
+
1922
+ #elif defined __AVX__
1923
+
1924
+ const __m128i m4 = _mm_set1_epi8(0xF);
1925
+ const __m128i m2 = _mm_set1_epi8(0x2);
1926
+
1927
+ __m256 acc = _mm256_setzero_ps();
1928
+ __m128 acc_m = _mm_setzero_ps();
1929
+
1930
+ for (int i = 0; i < nb; ++i) {
1931
+
1932
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
1933
+ const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
1934
+
1935
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
1936
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
1937
+
1938
+ memcpy(utmp, x[i].scales, 12);
1939
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1940
+ const uint32_t uaux = utmp[1] & kmask1;
1941
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1942
+ utmp[2] = uaux;
1943
+ utmp[0] &= kmask1;
1944
+
1945
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
1946
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
1947
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
1948
+
1949
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
1950
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
1951
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
1952
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
1953
+ acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
1954
+
1955
+ __m128i sumi_0 = _mm_setzero_si128();
1956
+ __m128i sumi_1 = _mm_setzero_si128();
1957
+
1958
+ __m128i shuffle = _mm_set1_epi16(0x0100);
1959
+ for (int j = 0; j < QK_K/64; ++j) {
1960
+
1961
+ const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
1962
+ shuffle = _mm_add_epi16(shuffle, m2);
1963
+ const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
1964
+ shuffle = _mm_add_epi16(shuffle, m2);
1965
+
1966
+ __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
1967
+ const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
1968
+ const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
1969
+ q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
1970
+ const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
1971
+ const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
1972
+
1973
+ const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1974
+ __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
1975
+ p16l = _mm_madd_epi16(scale_l, p16l);
1976
+ sumi_0 = _mm_add_epi32(sumi_0, p16l);
1977
+ const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1978
+ p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
1979
+ p16l = _mm_madd_epi16(scale_l, p16l);
1980
+ sumi_1 = _mm_add_epi32(sumi_1, p16l);
1981
+
1982
+ const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1983
+ __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
1984
+ p16h = _mm_madd_epi16(scale_h, p16h);
1985
+ sumi_0 = _mm_add_epi32(sumi_0, p16h);
1986
+ const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1987
+ p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
1988
+ p16h = _mm_madd_epi16(scale_h, p16h);
1989
+ sumi_1 = _mm_add_epi32(sumi_1, p16h);
1990
+
1991
+ }
1992
+
1993
+ __m256 vd = _mm256_set1_ps(d);
1994
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1995
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
1996
+
1997
+ }
1998
+
1999
+ acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
2000
+ acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
2001
+
2002
+ *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
2003
+
2004
+ #else
2005
+
2006
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
2007
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
2008
+
2009
+ int8_t aux8[QK_K];
2010
+ int16_t aux16[8];
2011
+ float sums [8];
2012
+ int32_t aux32[8];
2013
+ memset(sums, 0, 8*sizeof(float));
2014
+
2015
+ float sumf = 0;
2016
+ for (int i = 0; i < nb; ++i) {
2017
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
2018
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2019
+ memset(aux32, 0, 8*sizeof(int32_t));
2020
+ int8_t * LM_GGML_RESTRICT a = aux8;
2021
+ for (int j = 0; j < QK_K/64; ++j) {
2022
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
2023
+ a += 32;
2024
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
2025
+ a += 32; q4 += 32;
2026
+ }
2027
+ memcpy(utmp, x[i].scales, 12);
2028
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2029
+ const uint32_t uaux = utmp[1] & kmask1;
2030
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2031
+ utmp[2] = uaux;
2032
+ utmp[0] &= kmask1;
2033
+
2034
+ int sumi = 0;
2035
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
2036
+ a = aux8;
2037
+ int is = 0;
2038
+ for (int j = 0; j < QK_K/32; ++j) {
2039
+ int32_t scale = scales[is++];
2040
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2041
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2042
+ q8 += 8; a += 8;
2043
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2044
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2045
+ q8 += 8; a += 8;
2046
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2047
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2048
+ q8 += 8; a += 8;
2049
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2050
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2051
+ q8 += 8; a += 8;
2052
+ }
2053
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2054
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2055
+ const float dmin = LM_GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
2056
+ sumf -= dmin * sumi;
2057
+ }
2058
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2059
+ *s = sumf;
2060
+ #endif
2061
+ }
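q4_K and q5_K share the same 12-byte scale block: 8 six-bit scales and 8 six-bit mins, which the kmask shuffles above rearrange into `utmp` so that the bytes of `utmp[0..1]` are the scales and the bytes of `utmp[2..3]` are the mins. The same unpacking, lifted into a standalone sketch (function name illustrative):

```c
#include <stdint.h>
#include <string.h>

// Unpack the 12 scale/min bytes of a q4_K/q5_K block into 8 scale bytes
// (utmp[0..1]) and 8 min bytes (utmp[2..3]), as done inline in the kernels.
static void k_quant_unpack_scales_mins(const uint8_t scales[12], uint32_t utmp[4]) {
    const uint32_t kmask1 = 0x3f3f3f3f;
    const uint32_t kmask2 = 0x0f0f0f0f;
    const uint32_t kmask3 = 0x03030303;
    memcpy(utmp, scales, 12);
    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
    const uint32_t uaux = utmp[1] & kmask1;
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
    utmp[2] = uaux;
    utmp[0] &= kmask1;
}
```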
2062
+
2063
+ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
2064
+ assert(n % QK_K == 0);
2065
+ assert(nrc == 1);
2066
+ UNUSED(nrc);
2067
+ UNUSED(bx);
2068
+ UNUSED(by);
2069
+ UNUSED(bs);
2070
+
2071
+ const block_q5_K * LM_GGML_RESTRICT x = vx;
2072
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
2073
+
2074
+ const int nb = n / QK_K;
2075
+
2076
+ static const uint32_t kmask1 = 0x3f3f3f3f;
2077
+ static const uint32_t kmask2 = 0x0f0f0f0f;
2078
+ static const uint32_t kmask3 = 0x03030303;
2079
+
2080
+ uint32_t utmp[4];
2081
+
2082
+ #if defined __AVX2__
2083
+
2084
+ const __m256i m4 = _mm256_set1_epi8(0xF);
2085
+ const __m128i mzero = _mm_setzero_si128();
2086
+ const __m256i mone = _mm256_set1_epi8(1);
2087
+
2088
+ __m256 acc = _mm256_setzero_ps();
2089
+
2090
+ float summs = 0.f;
2091
+
2092
+ for (int i = 0; i < nb; ++i) {
2093
+ const uint8_t * LM_GGML_RESTRICT q5 = x[i].qs;
2094
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2095
+
2096
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
2097
+ const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
2098
+
2099
+ memcpy(utmp, x[i].scales, 12);
2100
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2101
+ const uint32_t uaux = utmp[1] & kmask1;
2102
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2103
+ utmp[2] = uaux;
2104
+ utmp[0] &= kmask1;
2105
+
2106
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
2107
+
2108
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
2109
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
2110
+ const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
2111
+ const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
2112
+ summs += dmin * _mm_extract_epi32(hsum, 0);
2113
+
2114
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
2115
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
2116
+
2117
+ const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
2118
+ __m256i hmask = mone;
2119
+
2120
+ __m256i sumi = _mm256_setzero_si256();
2121
+
2122
+ int bit = 0;
2123
+
2124
+ for (int j = 0; j < QK_K/64; ++j) {
2125
+
2126
+ const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
2127
+ const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
2128
+
2129
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
2130
+
2131
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
2132
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
2133
+ const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
2134
+ hmask = _mm256_slli_epi16(hmask, 1);
2135
+
2136
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
2137
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
2138
+ const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
2139
+ hmask = _mm256_slli_epi16(hmask, 1);
2140
+
2141
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2142
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2143
+
2144
+ __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
2145
+ __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
2146
+
2147
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
2148
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
2149
+
2150
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
2151
+
2152
+ }
2153
+
2154
+ __m256 vd = _mm256_set1_ps(d);
2155
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
2156
+
2157
+ }
2158
+
2159
+ *s = hsum_float_8(acc) + summs;
2160
+
2161
+ #elif defined __AVX__
2162
+
2163
+ const __m128i m4 = _mm_set1_epi8(0xF);
2164
+ const __m128i mzero = _mm_setzero_si128();
2165
+ const __m128i mone = _mm_set1_epi8(1);
2166
+ const __m128i m2 = _mm_set1_epi8(2);
2167
+
2168
+ __m256 acc = _mm256_setzero_ps();
2169
+
2170
+ float summs = 0.f;
2171
+
2172
+ for (int i = 0; i < nb; ++i) {
2173
+
2174
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
2175
+ const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
2176
+
2177
+ const uint8_t * LM_GGML_RESTRICT q5 = x[i].qs;
2178
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2179
+
2180
+ memcpy(utmp, x[i].scales, 12);
2181
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2182
+ const uint32_t uaux = utmp[1] & kmask1;
2183
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2184
+ utmp[2] = uaux;
2185
+ utmp[0] &= kmask1;
2186
+
2187
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
2188
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
2189
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
2190
+
2191
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
2192
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
2193
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
2194
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
2195
+ const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
2196
+ summs += dmin * _mm_extract_epi32(hsum, 0);
2197
+
2198
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
2199
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
2200
+ __m128i hmask = mone;
2201
+
2202
+ __m128i sumi_0 = _mm_setzero_si128();
2203
+ __m128i sumi_1 = _mm_setzero_si128();
2204
+
2205
+ int bit = 0;
2206
+
2207
+ __m128i shuffle = _mm_set1_epi16(0x0100);
2208
+ for (int j = 0; j < QK_K/64; ++j) {
2209
+
2210
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
2211
+ shuffle = _mm_add_epi16(shuffle, m2);
2212
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
2213
+ shuffle = _mm_add_epi16(shuffle, m2);
2214
+
2215
+ const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
2216
+ const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
2217
+
2218
+ __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
2219
+ __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
2220
+ __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
2221
+ __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
2222
+ __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
2223
+ __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
2224
+ hmask = _mm_slli_epi16(hmask, 1);
2225
+
2226
+ __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2227
+ __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2228
+ __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
2229
+ __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
2230
+ p16_0 = _mm_madd_epi16(scale_0, p16_0);
2231
+ p16_1 = _mm_madd_epi16(scale_0, p16_1);
2232
+
2233
+ q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
2234
+ q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
2235
+ q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
2236
+ q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
2237
+ q5_0 = _mm_add_epi8(q5l_0, q5h_0);
2238
+ q5_1 = _mm_add_epi8(q5l_1, q5h_1);
2239
+ hmask = _mm_slli_epi16(hmask, 1);
2240
+
2241
+ q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2242
+ q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2243
+ __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
2244
+ __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
2245
+ p16_2 = _mm_madd_epi16(scale_1, p16_2);
2246
+ p16_3 = _mm_madd_epi16(scale_1, p16_3);
2247
+
2248
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2249
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
2250
+
2251
+ }
2252
+
2253
+ __m256 vd = _mm256_set1_ps(d);
2254
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
2255
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
2256
+
2257
+ }
2258
+
2259
+ *s = hsum_float_8(acc) + summs;
2260
+
2261
+ #else
2262
+
2263
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
2264
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
2265
+
2266
+ int8_t aux8[QK_K];
2267
+ int16_t aux16[8];
2268
+ float sums [8];
2269
+ int32_t aux32[8];
2270
+ memset(sums, 0, 8*sizeof(float));
2271
+
2272
+ float sumf = 0;
2273
+ for (int i = 0; i < nb; ++i) {
2274
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
2275
+ const uint8_t * LM_GGML_RESTRICT hm = x[i].qh;
2276
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2277
+ memset(aux32, 0, 8*sizeof(int32_t));
2278
+ int8_t * LM_GGML_RESTRICT a = aux8;
2279
+ uint8_t m = 1;
2280
+ for (int j = 0; j < QK_K/64; ++j) {
2281
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
2282
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
2283
+ a += 32; m <<= 1;
2284
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
2285
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
2286
+ a += 32; m <<= 1;
2287
+ q4 += 32;
2288
+ }
2289
+ memcpy(utmp, x[i].scales, 12);
2290
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2291
+ const uint32_t uaux = utmp[1] & kmask1;
2292
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2293
+ utmp[2] = uaux;
2294
+ utmp[0] &= kmask1;
2295
+
2296
+ int sumi = 0;
2297
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
2298
+ a = aux8;
2299
+ int is = 0;
2300
+ for (int j = 0; j < QK_K/32; ++j) {
2301
+ int32_t scale = scales[is++];
2302
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2303
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2304
+ q8 += 8; a += 8;
2305
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2306
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2307
+ q8 += 8; a += 8;
2308
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2309
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2310
+ q8 += 8; a += 8;
2311
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2312
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2313
+ q8 += 8; a += 8;
2314
+ }
2315
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2316
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2317
+ const float dmin = LM_GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
2318
+ sumf -= dmin * sumi;
2319
+ }
2320
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2321
+ *s = sumf;
2322
+ #endif
2323
+ }
2324
+
2325
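Editor's note: the utmp/kmask shuffle that recurs in the q5_K paths above unpacks the 12-byte super-block `scales` field into eight 6-bit scales and eight 6-bit mins (the scalar fallback reads them through `(const uint8_t*)&utmp[0]` and `(const uint8_t*)&utmp[2]`). The standalone sketch below reuses those exact statements in a tiny harness; the kmask1/kmask2/kmask3 values are assumed to match the constants defined earlier in this file, and the helper name and test input are illustrative only.

// Minimal sketch of the q4_K/q5_K scale unpacking used in the kernels above.
// Assumption: kmask1/kmask2/kmask3 equal the constants defined earlier in this file.
#include <stdint.h>
#include <stdio.h>
#include <string.h>

static void unpack_k_scales(const uint8_t packed[12], uint8_t sc[8], uint8_t mn[8]) {
    const uint32_t kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303;
    uint32_t utmp[4];
    memcpy(utmp, packed, 12);
    // Same statement order as above: utmp[0..1] end up holding the eight
    // 6-bit scales, utmp[2..3] the eight 6-bit mins, one byte per value.
    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
    const uint32_t uaux = utmp[1] & kmask1;
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
    utmp[2] = uaux;
    utmp[0] &= kmask1;
    memcpy(sc, &utmp[0], 8);
    memcpy(mn, &utmp[2], 8);
}

int main(void) {
    uint8_t packed[12] = {0};
    packed[0] = 0x3f;  // scale 0 = 63
    packed[4] = 0x01;  // min 0 = 1
    uint8_t sc[8], mn[8];
    unpack_k_scales(packed, sc, mn);
    for (int j = 0; j < 8; ++j) printf("sc[%d]=%2u  mn[%d]=%2u\n", j, sc[j], j, mn[j]);
    return 0;
}

Once unpacked this way, the SIMD paths only need to broadcast one byte per sub-block instead of re-decoding 6-bit fields in the inner loop.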
+ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
2326
+ assert(n % QK_K == 0);
2327
+ assert(nrc == 1);
2328
+ UNUSED(nrc);
2329
+ UNUSED(bx);
2330
+ UNUSED(by);
2331
+ UNUSED(bs);
2332
+
2333
+ const block_q6_K * LM_GGML_RESTRICT x = vx;
2334
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
2335
+
2336
+ const int nb = n / QK_K;
2337
+
2338
+ #if defined __AVX2__
2339
+
2340
+ const __m256i m4 = _mm256_set1_epi8(0xF);
2341
+ const __m256i m2 = _mm256_set1_epi8(3);
2342
+ const __m256i m32s = _mm256_set1_epi8(32);
2343
+
2344
+ __m256 acc = _mm256_setzero_ps();
2345
+
2346
+ for (int i = 0; i < nb; ++i) {
2347
+
2348
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
2349
+
2350
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].ql;
2351
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
2352
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2353
+
2354
+ const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2355
+
2356
+ __m256i sumi = _mm256_setzero_si256();
2357
+
2358
+ int is = 0;
2359
+
2360
+ for (int j = 0; j < QK_K/128; ++j) {
2361
+
2362
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2363
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2364
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2365
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2366
+ is += 4;
2367
+
2368
+ const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
2369
+ const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
2370
+ const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
2371
+
2372
+ const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
2373
+ const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
2374
+ const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
2375
+ const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
2376
+
2377
+ const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
2378
+ const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
2379
+ const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
2380
+ const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
2381
+
2382
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2383
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2384
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2385
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2386
+
2387
+ __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
2388
+ __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
2389
+ __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
2390
+ __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
2391
+
2392
+ __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
2393
+ __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
2394
+ __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
2395
+ __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
2396
+
2397
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
2398
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
2399
+ p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
2400
+ p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
2401
+
2402
+ p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
2403
+ p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
2404
+ p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
2405
+ p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
2406
+
2407
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
2408
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
2409
+
2410
+ }
2411
+
2412
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
2413
+ }
2414
+
2415
+ *s = hsum_float_8(acc);
2416
+
2417
+ #elif defined __AVX__
2418
+
2419
+ const __m128i m3 = _mm_set1_epi8(3);
2420
+ const __m128i m15 = _mm_set1_epi8(15);
2421
+
2422
+ __m256 acc = _mm256_setzero_ps();
2423
+
2424
+ for (int i = 0; i < nb; ++i) {
2425
+
2426
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
2427
+
2428
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].ql;
2429
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
2430
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2431
+
2432
+ // handle the q6_K -32 offset separately using bsums (see the scalar sketch after this function)
2433
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
2434
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
2435
+ const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2436
+ const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
2437
+ const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
2438
+ const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
2439
+ const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
2440
+
2441
+ __m128i sumi_0 = _mm_setzero_si128();
2442
+ __m128i sumi_1 = _mm_setzero_si128();
2443
+
2444
+ int is = 0;
2445
+
2446
+ for (int j = 0; j < QK_K/128; ++j) {
2447
+
2448
+ const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
2449
+ const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
2450
+
2451
+ const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
2452
+ const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
2453
+ const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
2454
+ const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
2455
+ const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
2456
+ const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
2457
+ const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
2458
+ const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
2459
+
2460
+ const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2461
+ const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2462
+ const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2463
+ const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2464
+
2465
+ const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
2466
+ const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
2467
+ const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
2468
+ const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
2469
+ const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
2470
+ const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
2471
+ const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
2472
+ const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
2473
+
2474
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2475
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2476
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2477
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2478
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2479
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2480
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2481
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2482
+
2483
+ __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
2484
+ __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
2485
+ __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
2486
+ __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
2487
+ __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
2488
+ __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
2489
+ __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
2490
+ __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
2491
+
2492
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2493
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2494
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2495
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2496
+ is += 4;
2497
+
2498
+ p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
2499
+ p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
2500
+ p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
2501
+ p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
2502
+ p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
2503
+ p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
2504
+ p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
2505
+ p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
2506
+
2507
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2508
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
2509
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
2510
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
2511
+
2512
+ }
2513
+
2514
+ sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
2515
+ sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
2516
+ const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
2517
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
2518
+ }
2519
+
2520
+ *s = hsum_float_8(acc);
2521
+
2522
+ #else
2523
+
2524
+ int8_t aux8[QK_K];
2525
+ int16_t aux16[8];
2526
+ float sums [8];
2527
+ int32_t aux32[8];
2528
+ memset(sums, 0, 8*sizeof(float));
2529
+
2530
+ float sumf = 0;
2531
+ for (int i = 0; i < nb; ++i) {
2532
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].ql;
2533
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
2534
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2535
+ memset(aux32, 0, 8*sizeof(int32_t));
2536
+ int8_t * LM_GGML_RESTRICT a = aux8;
2537
+ for (int j = 0; j < QK_K; j += 128) {
2538
+ for (int l = 0; l < 32; ++l) {
2539
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
2540
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
2541
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
2542
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
2543
+ }
2544
+ a += 128;
2545
+ q4 += 64;
2546
+ qh += 32;
2547
+ }
2548
+ a = aux8;
2549
+ int is = 0;
2550
+ for (int j = 0; j < QK_K/16; ++j) {
2551
+ int scale = x[i].scales[is++];
2552
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2553
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2554
+ q8 += 8; a += 8;
2555
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2556
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2557
+ q8 += 8; a += 8;
2558
+ }
2559
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2560
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2561
+ }
2562
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2563
+ *s = sumf;
2564
+ #endif
2565
+ }
2566
+
2567
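Editor's note: the "handle the q6_K -32 offset separately using bsums" comment in the kernel above relies on the identity sum(s*(q-32)*y) = s*sum(q*y) - 32*s*sum(y), where sum(y) over each 16-element sub-block is already available in y[i].bsums. The minimal scalar sketch below (toy data, one sub-block, hypothetical names) checks that both forms agree; the shift by 5 in the AVX path is this same multiply-by-32.

// Minimal sketch: folding the q6_K -32 offset into a precomputed column sum.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    enum { SUB = 16 };        // one sub-block (a real q6_K block has 16 of these)
    int8_t q[SUB], q8[SUB];
    const int scale = 3;      // per-sub-block 8-bit scale

    for (int l = 0; l < SUB; ++l) { q[l] = (int8_t)(l * 4); q8[l] = (int8_t)(7 - l); }

    // Direct form: subtract the offset from every 6-bit quant.
    int direct = 0;
    for (int l = 0; l < SUB; ++l) direct += scale * (q[l] - 32) * q8[l];

    // bsums form: one dot product plus one precomputed sum of the int8 activations,
    // which is what block_q8_K stores in bsums.
    int dot = 0, bsum = 0;
    for (int l = 0; l < SUB; ++l) { dot += q[l] * q8[l]; bsum += q8[l]; }
    const int via_bsums = scale * dot - 32 * scale * bsum;

    printf("direct=%d  via_bsums=%d\n", direct, via_bsums);  // identical
    return 0;
}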
+ #if defined (__AVX__) || defined (__AVX2__)
2568
+ static const int8_t keven_signs_q2xs[1024] = {
2569
+ 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
2570
+ 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
2571
+ 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
2572
+ 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
2573
+ 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
2574
+ 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
2575
+ 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
2576
+ 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
2577
+ 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
2578
+ 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
2579
+ 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
2580
+ 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
2581
+ 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
2582
+ 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
2583
+ 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
2584
+ 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
2585
+ 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
2586
+ 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
2587
+ 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
2588
+ 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
2589
+ 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
2590
+ 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
2591
+ 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
2592
+ 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
2593
+ 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
2594
+ 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
2595
+ 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
2596
+ 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
2597
+ 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
2598
+ 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
2599
+ 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
2600
+ 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
2601
+ };
2602
+ #endif
2603
+
2604
+ void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
2605
+ assert(n % QK_K == 0);
2606
+ assert(nrc == 1);
2607
+ UNUSED(nrc);
2608
+ UNUSED(bx);
2609
+ UNUSED(by);
2610
+ UNUSED(bs);
2611
+
2612
+ const block_iq2_xxs * LM_GGML_RESTRICT x = vx;
2613
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
2614
+
2615
+ const int nb = n / QK_K;
2616
+
2617
+ #if defined(__AVX2__)
2618
+
2619
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2620
+
2621
+ uint32_t aux32[4];
2622
+ const uint8_t * aux8 = (const uint8_t *)aux32;
2623
+
2624
+ __m256 accumf = _mm256_setzero_ps();
2625
+ for (int i = 0; i < nb; ++i) {
2626
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2627
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
2628
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2629
+ __m256i sumi1 = _mm256_setzero_si256();
2630
+ __m256i sumi2 = _mm256_setzero_si256();
2631
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2632
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2633
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2634
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
2635
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
2636
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
2637
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
2638
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
2639
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
2640
+ signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
2641
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
2642
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
2643
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
2644
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
2645
+ const uint16_t ls1 = aux32[1] >> 28;
2646
+ const uint16_t ls2 = aux32[3] >> 28;
2647
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
2648
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
2649
+ sumi1 = _mm256_add_epi32(sumi1, p1);
2650
+ sumi2 = _mm256_add_epi32(sumi2, p2);
2651
+ }
2652
+
2653
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2654
+
2655
+ }
2656
+
2657
+ *s = 0.125f * hsum_float_8(accumf);
2658
+
2659
+ #elif defined(__AVX__)
2660
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2661
+
2662
+ uint32_t aux32[4];
2663
+ const uint8_t * aux8 = (const uint8_t *)aux32;
2664
+
2665
+ __m256 accumf = _mm256_setzero_ps();
2666
+ for (int i = 0; i < nb; ++i) {
2667
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2668
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
2669
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2670
+ __m128i sumi1_0 = _mm_setzero_si128();
2671
+ __m128i sumi1_1 = _mm_setzero_si128();
2672
+ __m128i sumi2_0 = _mm_setzero_si128();
2673
+ __m128i sumi2_1 = _mm_setzero_si128();
2674
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2675
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2676
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2677
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2678
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2679
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
2680
+ const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
2681
+ const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
2682
+ const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
2683
+ const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
2684
+ const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
2685
+ const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
2686
+ const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
2687
+ const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
2688
+ const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
2689
+ const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
2690
+ const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
2691
+ const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
2692
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2693
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2694
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2695
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2696
+ const uint16_t ls1 = aux32[1] >> 28;
2697
+ const uint16_t ls2 = aux32[3] >> 28;
2698
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
2699
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
2700
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
2701
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
2702
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
2703
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
2704
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
2705
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
2706
+ }
2707
+
2708
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2709
+
2710
+ }
2711
+
2712
+ *s = 0.125f * hsum_float_8(accumf);
2713
+
2714
+ #else
2715
+
2716
+ uint32_t aux32[2];
2717
+ const uint8_t * aux8 = (const uint8_t *)aux32;
2718
+
2719
+ float sumf = 0.f;
2720
+ for (int i = 0; i < nb; ++i) {
2721
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2722
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
2723
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2724
+ int32_t bsum = 0;
2725
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
2726
+ memcpy(aux32, q2, 2*sizeof(uint32_t));
2727
+ q2 += 4;
2728
+ const uint32_t ls = 2*(aux32[1] >> 28) + 1;
2729
+ int32_t sumi = 0;
2730
+ for (int l = 0; l < 4; ++l) {
2731
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
2732
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
2733
+ for (int j = 0; j < 8; ++j) {
2734
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
2735
+ }
2736
+ q8 += 8;
2737
+ }
2738
+ bsum += sumi * ls;
2739
+ }
2740
+ sumf += d * bsum;
2741
+ }
2742
+ *s = 0.125f * sumf;
2743
+ #endif
2744
+ }
2745
+
2746
+ void lm_ggml_vec_dot_iq2_xs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
2747
+ assert(n % QK_K == 0);
2748
+ assert(nrc == 1);
2749
+ UNUSED(nrc);
2750
+ UNUSED(bx);
2751
+ UNUSED(by);
2752
+ UNUSED(bs);
2753
+
2754
+ const block_iq2_xs * LM_GGML_RESTRICT x = vx;
2755
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
2756
+
2757
+ const int nb = n / QK_K;
2758
+
2759
+ #if defined(__AVX2__)
2760
+
2761
+ const __m256i mone = _mm256_set1_epi8(1);
2762
+ static const char block_sign_shuffle_mask_1[32] = {
2763
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
2764
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
2765
+ };
2766
+ static const char block_sign_shuffle_mask_2[32] = {
2767
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
2768
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
2769
+ };
2770
+ static const uint8_t bit_selector_mask_bytes[32] = {
2771
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2772
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2773
+ };
2774
+
2775
+ const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
2776
+ const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
2777
+ const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
2778
+
2779
+ static const uint8_t k_bit_helper[32] = {
2780
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2781
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2782
+ };
2783
+ const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
2784
+ const __m256i m511 = _mm256_set1_epi16(511);
2785
+ const __m128i m4 = _mm_set1_epi8(0xf);
2786
+ const __m128i m1 = _mm_set1_epi8(1);
2787
+
2788
+ uint64_t aux64;
2789
+
2790
+ // somewhat hacky, but gives a significant boost in performance (the pattern is sketched standalone after this function)
2791
+ __m256i aux_gindex;
2792
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
2793
+
2794
+ __m256 accumf = _mm256_setzero_ps();
2795
+ for (int i = 0; i < nb; ++i) {
2796
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2797
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
2798
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2799
+
2800
+ memcpy(&aux64, x[i].scales, 8);
2801
+ __m128i stmp = _mm_set1_epi64x(aux64);
2802
+ stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
2803
+ const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
2804
+
2805
+ __m256i sumi1 = _mm256_setzero_si256();
2806
+ __m256i sumi2 = _mm256_setzero_si256();
2807
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
2808
+
2809
+ const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
2810
+ aux_gindex = _mm256_and_si256(q2_data, m511);
2811
+
2812
+ const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
2813
+ const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
2814
+ const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
2815
+
2816
+ const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
2817
+ const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
2818
+
2819
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2820
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2821
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2822
+ const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2823
+
2824
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
2825
+ iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
2826
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
2827
+ iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
2828
+ const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
2829
+ iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
2830
+ const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
2831
+ iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
2832
+
2833
+ const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
2834
+ const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
2835
+ const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
2836
+ const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
2837
+
2838
+ __m256i signs;
2839
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
2840
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2841
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
2842
+
2843
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
2844
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2845
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
2846
+
2847
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
2848
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2849
+ const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
2850
+
2851
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
2852
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2853
+ const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
2854
+
2855
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
2856
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
2857
+ const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
2858
+ const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
2859
+
2860
+ const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
2861
+ const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
2862
+ const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
2863
+ const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
2864
+
2865
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
2866
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
2867
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
2868
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
2869
+ }
2870
+
2871
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2872
+
2873
+ }
2874
+
2875
+ *s = 0.125f * hsum_float_8(accumf);
2876
+
2877
+ #elif defined(__AVX__)
2878
+ const __m128i mone = _mm_set1_epi8(1);
2879
+ static const char block_sign_shuffle_mask_1[32] = {
2880
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
2881
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
2882
+ };
2883
+ static const char block_sign_shuffle_mask_2[32] = {
2884
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
2885
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
2886
+ };
2887
+ static const uint8_t bit_selector_mask_bytes[32] = {
2888
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2889
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2890
+ };
2891
+
2892
+ const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
2893
+ const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
2894
+ const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
2895
+ const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
2896
+ const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
2897
+ const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
2898
+
2899
+ static const uint8_t k_bit_helper[32] = {
2900
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2901
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2902
+ };
2903
+ const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
2904
+ const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
2905
+ const __m128i m511 = _mm_set1_epi16(511);
2906
+ const __m128i m4 = _mm_set1_epi8(0xf);
2907
+ const __m128i m1 = _mm_set1_epi8(1);
2908
+
2909
+ uint64_t aux64;
2910
+
2911
+ // somewhat hacky, but gives a significant boost in performance
2912
+ __m256i aux_gindex;
2913
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
2914
+
2915
+ __m256 accumf = _mm256_setzero_ps();
2916
+ for (int i = 0; i < nb; ++i) {
2917
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2918
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
2919
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
2920
+
2921
+ memcpy(&aux64, x[i].scales, 8);
2922
+ __m128i stmp = _mm_set1_epi64x(aux64);
2923
+ stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
2924
+ const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
2925
+
2926
+ __m128i sumi1_0 = _mm_setzero_si128();
2927
+ __m128i sumi1_1 = _mm_setzero_si128();
2928
+ __m128i sumi2_0 = _mm_setzero_si128();
2929
+ __m128i sumi2_1 = _mm_setzero_si128();
2930
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
2931
+
2932
+ const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
2933
+ const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16;
2934
+ aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
2935
+
2936
+ const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
2937
+ const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
2938
+ const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
2939
+ const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
2940
+ const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
2941
+ const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
2942
+
2943
+ const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
2944
+ const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
2945
+ const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
2946
+ const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
2947
+
2948
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2949
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2950
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2951
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2952
+ const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2953
+ const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2954
+ const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2955
+ const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2956
+
2957
+ const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
2958
+ const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
2959
+ const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
2960
+ const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
2961
+ const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
2962
+ const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
2963
+ const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
2964
+ const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
2965
+
2966
+ // AVX2 full_signs_1 is full_sign_bits_0 here
2967
+ // AVX2 full_signs_2 is full_sign_bits_1 here
2968
+ __m128i signs_0, signs_1;
2969
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
2970
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
2971
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2972
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2973
+ const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
2974
+ const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
2975
+
2976
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
2977
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
2978
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2979
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2980
+ const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
2981
+ const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
2982
+
2983
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
2984
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
2985
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2986
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2987
+ const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
2988
+ const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
2989
+
2990
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
2991
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
2992
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2993
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2994
+ const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
2995
+ const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
2996
+
2997
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2998
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2999
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3000
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3001
+ const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
3002
+ const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
3003
+ const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
3004
+ const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
3005
+
3006
+ __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
3007
+ const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
3008
+ const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
3009
+ sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
3010
+ const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
3011
+ const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
3012
+ sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
3013
+ const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
3014
+ const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
3015
+ sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
3016
+ const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
3017
+ const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
3018
+
3019
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
3020
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
3021
+ sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
3022
+ sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
3023
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
3024
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
3025
+ sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
3026
+ sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
3027
+ }
3028
+
3029
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3030
+
3031
+ }
3032
+
3033
+ *s = 0.125f * hsum_float_8(accumf);
3034
+
3035
+ #else
3036
+
3037
+ float sumf = 0.f;
3038
+ for (int i = 0; i < nb; ++i) {
3039
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3040
+ const uint16_t * LM_GGML_RESTRICT q2 = x[i].qs;
3041
+ const uint8_t * LM_GGML_RESTRICT sc = x[i].scales;
3042
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3043
+ int32_t bsum = 0;
3044
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3045
+ const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
3046
+ const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
3047
+ int32_t sumi = 0;
3048
+ for (int l = 0; l < 2; ++l) {
3049
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
3050
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
3051
+ for (int j = 0; j < 8; ++j) {
3052
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
3053
+ }
3054
+ q8 += 8;
3055
+ }
3056
+ bsum += sumi * ls1;
3057
+ sumi = 0;
3058
+ for (int l = 2; l < 4; ++l) {
3059
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
3060
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
3061
+ for (int j = 0; j < 8; ++j) {
3062
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
3063
+ }
3064
+ q8 += 8;
3065
+ }
3066
+ bsum += sumi * ls2;
3067
+ q2 += 4;
3068
+ }
3069
+ sumf += d * bsum;
3070
+ }
3071
+ *s = 0.125f * sumf;
3072
+ #endif
3073
+ }
3074
+
3075
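Editor's note: the "somewhat hacky" trick in the iq2_xs kernel above stores the 511-masked grid indices into a stack __m256i (aux_gindex) and then reads its sixteen 16-bit lanes back through a plain uint16_t pointer, avoiding sixteen per-lane extract intrinsics. The minimal AVX2 sketch below isolates just that pattern (illustrative values, compile with -mavx2); like the original, it relies on the compiler accepting this type-punned read of vector storage.

// Minimal sketch of the aux_gindex lane-aliasing trick used above.
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Pretend these are packed 16-bit codes whose low 9 bits index a grid table.
    const __m256i codes = _mm256_set1_epi16(0x1234);
    const __m256i m511  = _mm256_set1_epi16(511);

    __m256i aux_gindex;                                    // lives on the stack
    const uint16_t * gindex = (const uint16_t *)&aux_gindex;

    aux_gindex = _mm256_and_si256(codes, m511);            // mask once, in vector registers

    // ...then index scalar lookup tables through the aliased 16-bit lanes.
    for (int k = 0; k < 16; ++k) printf("%u ", gindex[k]); // prints 52 sixteen times
    printf("\n");
    return 0;
}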
+ void lm_ggml_vec_dot_iq2_s_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3076
+ assert(n % QK_K == 0);
3077
+ assert(nrc == 1);
3078
+ UNUSED(nrc);
3079
+ UNUSED(bx);
3080
+ UNUSED(by);
3081
+ UNUSED(bs);
3082
+
3083
+ const block_iq2_s * LM_GGML_RESTRICT x = vx;
3084
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3085
+
3086
+ const int nb = n / QK_K;
3087
+
3088
+ #if defined(__AVX2__)
3089
+
3090
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3091
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3092
+ };
3093
+
3094
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3095
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3096
+ };
3097
+
3098
+ const __m128i m4 = _mm_set1_epi8(0xf);
3099
+ const __m128i m1 = _mm_set1_epi8(1);
3100
+
3101
+ const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
3102
+ const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
3103
+
3104
+ uint64_t aux64;
3105
+
3106
+ __m256 accumf = _mm256_setzero_ps();
3107
+ for (int i = 0; i < nb; ++i) {
3108
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3109
+ const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
3110
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3111
+ const uint16_t * LM_GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
3112
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3113
+
3114
+ memcpy(&aux64, x[i].scales, 8);
3115
+ const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
3116
+ const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
3117
+
3118
+ __m256i sumi1 = _mm256_setzero_si256();
3119
+ __m256i sumi2 = _mm256_setzero_si256();
3120
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3121
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3122
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3123
+ const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
3124
+ iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
3125
+ iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
3126
+ iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
3127
+ const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
3128
+ iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
3129
+ iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
3130
+ iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
3131
+ qs += 8;
3132
+
3133
+ __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
3134
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3135
+ const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
3136
+ const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
3137
+
3138
+ aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
3139
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3140
+ const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
3141
+ const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
3142
+
3143
+ signs += 4;
3144
+
3145
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
3146
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
3147
+
3148
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0)));
3149
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1)));
3150
+ sumi1 = _mm256_add_epi32(sumi1, p1);
3151
+ sumi2 = _mm256_add_epi32(sumi2, p2);
3152
+ }
3153
+
3154
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
3155
+
3156
+ }
3157
+
3158
+ *s = 0.125f * hsum_float_8(accumf);
3159
+
3160
+ #elif defined(__AVX__)
3161
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3162
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3163
+ };
3164
+
3165
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3166
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3167
+ };
3168
+
3169
+ const __m128i m4 = _mm_set1_epi8(0xf);
3170
+ const __m128i m1 = _mm_set1_epi8(1);
3171
+
3172
+ const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
3173
+ const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
3174
+ const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
3175
+ const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
3176
+
3177
+ uint64_t aux64;
3178
+
3179
+ __m256 accumf = _mm256_setzero_ps();
3180
+ for (int i = 0; i < nb; ++i) {
3181
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3182
+ const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
3183
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3184
+ const uint16_t * LM_GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
3185
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3186
+
3187
+ memcpy(&aux64, x[i].scales, 8);
3188
+ const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
3189
+ const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
3190
+ const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
3191
+
3192
+ __m128i sumi1_0 = _mm_setzero_si128();
3193
+ __m128i sumi1_1 = _mm_setzero_si128();
3194
+ __m128i sumi2_0 = _mm_setzero_si128();
3195
+ __m128i sumi2_1 = _mm_setzero_si128();
3196
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3197
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3198
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3199
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3200
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3201
+ const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
3202
+ iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
3203
+ const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
3204
+ iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
3205
+ const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
3206
+ iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
3207
+ const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
3208
+ iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
3209
+ qs += 8;
3210
+
3211
+ __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
3212
+ __m128i aux128_1 = aux128_0;
3213
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3214
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3215
+ const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3216
+ const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3217
+ const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
3218
+ const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
3219
+
3220
+ aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
3221
+ aux128_1 = aux128_0;
3222
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3223
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3224
+ const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3225
+ const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3226
+ const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
3227
+ const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
3228
+
3229
+ signs += 4;
3230
+
3231
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
3232
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
3233
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3234
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3235
+
3236
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
3237
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
3238
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
3239
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
3240
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
3241
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
3242
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
3243
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
3244
+ }
3245
+
3246
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3247
+
3248
+ }
3249
+
3250
+ *s = 0.125f * hsum_float_8(accumf);
3251
+
3252
+ #else
3253
+
3254
+ float sumf = 0;
3255
+ for (int i = 0; i < nb; i++) {
3256
+
3257
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3258
+ const int8_t * q8 = y[i].qs;
3259
+ const uint8_t * qs = x[i].qs;
3260
+ const uint8_t * qh = x[i].qh;
3261
+ const uint8_t * signs = qs + QK_K/8;
3262
+
3263
+ int bsum = 0;
3264
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3265
+ int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
3266
+ int ls2 = 1 + 2*(x[i].scales[ib32] >> 4);
3267
+ int sumi1 = 0, sumi2 = 0;
3268
+ for (int l = 0; l < 2; ++l) {
3269
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
3270
+ for (int j = 0; j < 8; ++j) {
3271
+ sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
3272
+ }
3273
+ q8 += 8;
3274
+ }
3275
+ for (int l = 2; l < 4; ++l) {
3276
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
3277
+ for (int j = 0; j < 8; ++j) {
3278
+ sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
3279
+ }
3280
+ q8 += 8;
3281
+ }
3282
+ bsum += ls1 * sumi1 + ls2 * sumi2;
3283
+ qs += 4;
3284
+ signs += 4;
3285
+ }
3286
+
3287
+ sumf += d * bsum;
3288
+ }
3289
+
3290
+ *s = 0.125f * sumf;
3291
+
3292
+ #endif
3293
+
3294
+ }
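The SSE and AVX paths above apply the per-weight signs without a multiply: a byte mask built from the packed sign bits is XORed into the q8 bytes and then subtracted, which negates exactly the lanes whose mask byte is 0xFF and matches the scalar (signs[l] & kmask_iq2xs[j] ? -1 : 1) factor in the fallback. A minimal scalar sketch of that identity (the helper name is illustrative, not part of the file):

#include <stdint.h>

// Branch-free sign application used by the SIMD path above:
// m_byte is 0x00 (keep) or 0xFF (negate); (q ^ m) - m == (m ? -q : q).
static inline int8_t apply_sign_sketch(int8_t q, uint8_t m_byte) {
    int8_t m = (int8_t)m_byte;      // 0 or -1
    return (int8_t)((q ^ m) - m);
}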
3295
+
3296
+ void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3297
+ assert(n % QK_K == 0);
3298
+ assert(nrc == 1);
3299
+ UNUSED(nrc);
3300
+ UNUSED(bx);
3301
+ UNUSED(by);
3302
+ UNUSED(bs);
3303
+
3304
+ const block_iq3_xxs * LM_GGML_RESTRICT x = vx;
3305
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3306
+
3307
+ const int nb = n / QK_K;
3308
+
3309
+ #if defined(__AVX2__)
3310
+
3311
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3312
+
3313
+ uint32_t aux32[2];
3314
+
3315
+ __m256 accumf = _mm256_setzero_ps();
3316
+ for (int i = 0; i < nb; ++i) {
3317
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3318
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
3319
+ const uint8_t * LM_GGML_RESTRICT gas = x[i].qs + QK_K/4;
3320
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3321
+ __m256i sumi1 = _mm256_setzero_si256();
3322
+ __m256i sumi2 = _mm256_setzero_si256();
3323
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3324
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3325
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3326
+ const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
3327
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3328
+ q3 += 8;
3329
+ const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
3330
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3331
+ q3 += 8;
3332
+ memcpy(aux32, gas, 8); gas += 8;
3333
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
3334
+ signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
3335
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
3336
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
3337
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
3338
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
3339
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
3340
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
3341
+ const uint16_t ls1 = aux32[0] >> 28;
3342
+ const uint16_t ls2 = aux32[1] >> 28;
3343
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
3344
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
3345
+ sumi1 = _mm256_add_epi32(sumi1, p1);
3346
+ sumi2 = _mm256_add_epi32(sumi2, p2);
3347
+ }
3348
+
3349
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
3350
+
3351
+ }
3352
+
3353
+ *s = 0.25f * hsum_float_8(accumf);
3354
+
3355
+ #elif defined(__AVX__)
3356
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3357
+
3358
+ uint32_t aux32[2];
3359
+
3360
+ __m256 accumf = _mm256_setzero_ps();
3361
+ for (int i = 0; i < nb; ++i) {
3362
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3363
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
3364
+ const uint8_t * LM_GGML_RESTRICT gas = x[i].qs + QK_K/4;
3365
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3366
+ __m128i sumi1_0 = _mm_setzero_si128();
3367
+ __m128i sumi1_1 = _mm_setzero_si128();
3368
+ __m128i sumi2_0 = _mm_setzero_si128();
3369
+ __m128i sumi2_1 = _mm_setzero_si128();
3370
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3371
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3372
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3373
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3374
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3375
+ const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3376
+ const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
3377
+ q3 += 8;
3378
+ const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3379
+ const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
3380
+ q3 += 8;
3381
+ memcpy(aux32, gas, 8); gas += 8;
3382
+ const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
3383
+ const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
3384
+ const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
3385
+ const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
3386
+ const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
3387
+ const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
3388
+ const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
3389
+ const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
3390
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
3391
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
3392
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3393
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3394
+ const uint16_t ls1 = aux32[0] >> 28;
3395
+ const uint16_t ls2 = aux32[1] >> 28;
3396
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
3397
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
3398
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
3399
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
3400
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
3401
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
3402
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
3403
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
3404
+ }
3405
+
3406
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3407
+
3408
+ }
3409
+
3410
+ *s = 0.25f * hsum_float_8(accumf);
3411
+
3412
+ #else
3413
+
3414
+ uint32_t aux32;
3415
+
3416
+ float sumf = 0.f;
3417
+ for (int i = 0; i < nb; ++i) {
3418
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3419
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
3420
+ const uint8_t * LM_GGML_RESTRICT gas = x[i].qs + QK_K/4;
3421
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3422
+ int32_t bsum = 0;
3423
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
3424
+ memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
3425
+ const uint32_t ls = 2*(aux32 >> 28) + 1;
3426
+ int32_t sumi = 0;
3427
+ for (int l = 0; l < 4; ++l) {
3428
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
3429
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
3430
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
3431
+ for (int j = 0; j < 4; ++j) {
3432
+ sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
3433
+ sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
3434
+ }
3435
+ q8 += 8;
3436
+ }
3437
+ q3 += 8;
3438
+ bsum += sumi * ls;
3439
+ }
3440
+ sumf += d * bsum;
3441
+ }
3442
+ *s = 0.25f * sumf;
3443
+ #endif
3444
+ }
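In lm_ggml_vec_dot_iq3_xxs_q8_K, each 32-bit word read from gas carries four 7-bit sign indices in its low 28 bits and a 4-bit scale in its top 4 bits; the effective block weight is 2*scale + 1 and the final sum is multiplied by 0.25f. A small scalar sketch of the unpacking, assuming only what the fallback branch above shows (helper name illustrative):

#include <stdint.h>

// Unpack one aux32 word of an IQ3_XXS block (mirrors the scalar fallback above).
static inline void iq3xxs_unpack_aux_sketch(uint32_t aux32, uint8_t sign_idx[4], int *ls) {
    for (int l = 0; l < 4; ++l) {
        sign_idx[l] = (uint8_t)((aux32 >> (7*l)) & 127);  // index into ksigns_iq2xs
    }
    *ls = 2*(int)(aux32 >> 28) + 1;                       // block weight; result scaled by 0.25f
}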
3445
+
3446
+ void lm_ggml_vec_dot_iq3_s_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3447
+ assert(n % QK_K == 0);
3448
+ assert(nrc == 1);
3449
+ UNUSED(nrc);
3450
+ UNUSED(bx);
3451
+ UNUSED(by);
3452
+ UNUSED(bs);
3453
+
3454
+ const block_iq3_s * LM_GGML_RESTRICT x = vx;
3455
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3456
+
3457
+ const int nb = n / QK_K;
3458
+
3459
+ #if defined(__AVX2__)
3460
+
3461
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3462
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3463
+ };
3464
+
3465
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3466
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3467
+ };
3468
+
3469
+ const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
3470
+ const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
3471
+
3472
+ const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
3473
+ const __m256i idx_mask = _mm256_set1_epi32(256);
3474
+
3475
+ typedef union {
3476
+ __m256i vec[2];
3477
+ uint32_t index[16];
3478
+ } index_t;
3479
+
3480
+ index_t idx;
3481
+
3482
+ __m256 accumf = _mm256_setzero_ps();
3483
+ for (int i = 0; i < nb; ++i) {
3484
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3485
+ const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
3486
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3487
+ const uint16_t * LM_GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3488
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3489
+ __m256i sumi1 = _mm256_setzero_si256();
3490
+ __m256i sumi2 = _mm256_setzero_si256();
3491
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3492
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3493
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3494
+ const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
3495
+ idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
3496
+ idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
3497
+ idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
3498
+ idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
3499
+ idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
3500
+ idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
3501
+
3502
+ // At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
3503
+ //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
3504
+ //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
3505
+ const __m256i q2_1 = _mm256_set_epi32(
3506
+ iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
3507
+ iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
3508
+ );
3509
+ const __m256i q2_2 = _mm256_set_epi32(
3510
+ iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
3511
+ iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
3512
+ );
3513
+
3514
+ __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
3515
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3516
+ const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
3517
+ const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
3518
+
3519
+ aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
3520
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3521
+ const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
3522
+ const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
3523
+
3524
+ signs += 4;
3525
+
3526
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
3527
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
3528
+ const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
3529
+ const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
3530
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
3531
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
3532
+ sumi1 = _mm256_add_epi32(sumi1, p1);
3533
+ sumi2 = _mm256_add_epi32(sumi2, p2);
3534
+ }
3535
+
3536
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
3537
+
3538
+ }
3539
+
3540
+ *s = hsum_float_8(accumf);
3541
+
3542
+ #elif defined(__AVX__)
3543
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3544
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3545
+ };
3546
+
3547
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3548
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3549
+ };
3550
+
3551
+ const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
3552
+ const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
3553
+ const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
3554
+ const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
3555
+
3556
+ const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
3557
+ const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
3558
+ const __m128i idx_mask = _mm_set1_epi32(256);
3559
+
3560
+ typedef union {
3561
+ __m128i vec[4];
3562
+ uint32_t index[16];
3563
+ } index_t;
3564
+
3565
+ index_t idx;
3566
+
3567
+ __m256 accumf = _mm256_setzero_ps();
3568
+ for (int i = 0; i < nb; ++i) {
3569
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3570
+ const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
3571
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3572
+ const uint16_t * LM_GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3573
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3574
+ __m128i sumi1_0 = _mm_setzero_si128();
3575
+ __m128i sumi1_1 = _mm_setzero_si128();
3576
+ __m128i sumi2_0 = _mm_setzero_si128();
3577
+ __m128i sumi2_1 = _mm_setzero_si128();
3578
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3579
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3580
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3581
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3582
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3583
+ const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
3584
+ const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
3585
+ const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
3586
+ idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
3587
+ idx.vec[1] = idx.vec[0];
3588
+ idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
3589
+ idx.vec[3] = idx.vec[2];
3590
+
3591
+ idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
3592
+ idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
3593
+ idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
3594
+ idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
3595
+
3596
+ idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
3597
+ idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
3598
+ idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
3599
+ idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
3600
+
3601
+ const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
3602
+ const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
3603
+ const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
3604
+ const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
3605
+
3606
+ __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
3607
+ __m128i aux128_1 = aux128_0;
3608
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3609
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3610
+ const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3611
+ const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3612
+ const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
3613
+ const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
3614
+
3615
+ aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));
3616
+ aux128_1 = aux128_0;
3617
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3618
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3619
+ const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3620
+ const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3621
+ const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
3622
+ const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
3623
+
3624
+ signs += 4;
3625
+
3626
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
3627
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
3628
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3629
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3630
+ const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
3631
+ const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
3632
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
3633
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
3634
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
3635
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
3636
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
3637
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
3638
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
3639
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
3640
+ }
3641
+
3642
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3643
+
3644
+ }
3645
+
3646
+ *s = hsum_float_8(accumf);
3647
+
3648
+ #else
3649
+
3650
+ float sumf = 0.f;
3651
+ for (int i = 0; i < nb; ++i) {
3652
+ const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3653
+ const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
3654
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
3655
+ const uint8_t * LM_GGML_RESTRICT signs = x[i].signs;
3656
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
3657
+ int32_t bsum = 0;
3658
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3659
+ const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
3660
+ const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
3661
+ int32_t sumi = 0;
3662
+ for (int l = 0; l < 4; ++l) {
3663
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
3664
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
3665
+ for (int j = 0; j < 4; ++j) {
3666
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
3667
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
3668
+ }
3669
+ q8 += 8;
3670
+ }
3671
+ qs += 8;
3672
+ signs += 4;
3673
+ bsum += sumi * ls1;
3674
+ sumi = 0;
3675
+ for (int l = 0; l < 4; ++l) {
3676
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
3677
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
3678
+ for (int j = 0; j < 4; ++j) {
3679
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
3680
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
3681
+ }
3682
+ q8 += 8;
3683
+ }
3684
+ qs += 8;
3685
+ signs += 4;
3686
+ bsum += sumi * ls2;
3687
+ }
3688
+ sumf += d * bsum;
3689
+ }
3690
+ *s = sumf;
3691
+ #endif
3692
+ }
3693
+
3694
+ #if defined(__AVX2__)
3695
+ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
3696
+ const __m256i ax = _mm256_sign_epi8(x, x);
3697
+ const __m256i sy = _mm256_sign_epi8(y, x);
3698
+ return _mm256_maddubs_epi16(ax, sy);
3699
+ }
3700
+ #endif
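mul_add_epi8 exists because _mm256_maddubs_epi16 treats its first operand as unsigned: moving the sign of x onto y (abs(x), y*sign(x)) keeps every byte product equal to x*y while satisfying that constraint. Per 16-bit output lane it therefore computes the following (sketch, ignoring the instruction's saturation edge case; helper name illustrative):

#include <stdint.h>

// One 16-bit lane of mul_add_epi8(x, y): the sum of two signed byte products,
// because |x| * (y * sign(x)) == x * y.
static inline int16_t mul_add_epi8_lane_sketch(const int8_t x[2], const int8_t y[2]) {
    return (int16_t)(x[0]*y[0] + x[1]*y[1]);
}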
3701
+
3702
+ void lm_ggml_vec_dot_iq1_s_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3703
+ assert(n % QK_K == 0);
3704
+ assert(nrc == 1);
3705
+ UNUSED(nrc);
3706
+ UNUSED(bx);
3707
+ UNUSED(by);
3708
+ UNUSED(bs);
3709
+
3710
+ const block_iq1_s * LM_GGML_RESTRICT x = vx;
3711
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3712
+
3713
+ const int nb = n / QK_K;
3714
+
3715
+ #if defined __AVX2__
3716
+
3717
+ __m256 accum = _mm256_setzero_ps();
3718
+ float accum1 = 0;
3719
+ for (int i = 0; i < nb; ++i) {
3720
+
3721
+ const int8_t * q8 = y[i].qs;
3722
+ const uint8_t * qs = x[i].qs;
3723
+ const uint16_t * qh = x[i].qh;
3724
+
3725
+ __m256i sumi = _mm256_setzero_si256();
3726
+ int sumi1 = 0;
3727
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
3728
+ #ifdef __BMI2__
3729
+ const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib], 0x700070007000700ULL);
3730
+ const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib + 1], 0x700070007000700ULL);
3731
+ const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
3732
+ const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
3733
+ const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
3734
+ const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
3735
+ #else
3736
+ const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
3737
+ iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
3738
+ const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
3739
+ iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
3740
+ #endif
3741
+ qs += 8;
3742
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3743
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3744
+
3745
+ const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
3746
+ const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
3747
+ const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
3748
+ const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
3749
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
3750
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
3751
+
3752
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
3753
+ sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
3754
+ + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
3755
+ }
3756
+
3757
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
3758
+ accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
3759
+ accum1 += d * sumi1;
3760
+
3761
+ }
3762
+
3763
+ *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
3764
+
3765
+ #elif defined __AVX__
3766
+ __m256 accum = _mm256_setzero_ps();
3767
+ float accum1 = 0;
3768
+ for (int i = 0; i < nb; ++i) {
3769
+
3770
+ const int8_t * q8 = y[i].qs;
3771
+ const uint8_t * qs = x[i].qs;
3772
+ const uint16_t * qh = x[i].qh;
3773
+
3774
+ __m128i sumi1_0 = _mm_setzero_si128();
3775
+ __m128i sumi1_1 = _mm_setzero_si128();
3776
+ int sumi1 = 0;
3777
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
3778
+ const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
3779
+ const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
3780
+ const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
3781
+ const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
3782
+ qs += 8;
3783
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3784
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3785
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3786
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3787
+
3788
+ const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
3789
+ const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
3790
+ const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
3791
+ const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
3792
+ const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
3793
+ const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
3794
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
3795
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
3796
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
3797
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));
3798
+
3799
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
3800
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
3801
+ sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
3802
+ + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
3803
+ }
3804
+
3805
+ const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
3806
+ accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
3807
+ accum1 += d * sumi1;
3808
+
3809
+ }
3810
+
3811
+ *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
3812
+
3813
+ #else
3814
+
3815
+ float sumf = 0;
3816
+ for (int i = 0; i < nb; i++) {
3817
+
3818
+ const int8_t * q8 = y[i].qs;
3819
+ const uint8_t * qs = x[i].qs;
3820
+ const uint16_t * qh = x[i].qh;
3821
+
3822
+ int sumi = 0, sumi1 = 0;
3823
+ for (int ib = 0; ib < QK_K/32; ++ib) {
3824
+ const int ls = 2*((qh[ib] >> 12) & 7) + 1;
3825
+ const int delta = qh[ib] & 0x8000 ? -1 : 1;
3826
+ int lsum = 0;
3827
+ for (int l = 0; l < 4; ++l) {
3828
+ const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
3829
+ for (int j = 0; j < 8; ++j) {
3830
+ lsum += q8[j] * grid[j];
3831
+ }
3832
+ q8 += 8;
3833
+ }
3834
+ sumi += ls * lsum;
3835
+ sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
3836
+ qs += 4;
3837
+ }
3838
+
3839
+ sumf += LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
3840
+ }
3841
+
3842
+ *s = sumf;
3843
+
3844
+ #endif
3845
+ }
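In the BMI2 path of lm_ggml_vec_dot_iq1_s_q8_K, _pdep_u64 spreads four qs bytes into the low byte of each 16-bit slot and deposits three qh bits into bits 8-10 of the same slot, so every uint16 of packed_idx is a ready-made 11-bit index into iq1s_grid; the #else branch and the scalar fallback build the same index with shifts and masks. A scalar sketch of the index per lane (helper name illustrative):

#include <stdint.h>

// Grid index for lane l (0..3) of one 32-weight IQ1_S group:
// 8 low bits from qs, 3 high bits from qh.
static inline uint16_t iq1s_index_sketch(const uint8_t *qs, uint16_t qh, int l) {
    return (uint16_t)(qs[l] | (((qh >> (3*l)) & 7) << 8));
}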
3846
+
3847
+ void lm_ggml_vec_dot_iq1_m_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
3848
+ assert(n % QK_K == 0);
3849
+ assert(nrc == 1);
3850
+ UNUSED(nrc);
3851
+ UNUSED(bx);
3852
+ UNUSED(by);
3853
+ UNUSED(bs);
3854
+
3855
+ const block_iq1_m * LM_GGML_RESTRICT x = vx;
3856
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
3857
+
3858
+ const int nb = n / QK_K;
3859
+
3860
+ iq1m_scale_t scale;
3861
+
3862
+ #if defined __AVX2__
3863
+
3864
+ const __m256i mask = _mm256_set1_epi16(0x7);
3865
+ const __m256i mone = _mm256_set1_epi16(1);
3866
+ const __m256i mone8 = _mm256_set1_epi8(1);
3867
+ const __m256i mtwo8 = _mm256_set1_epi8(2);
3868
+ // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half.
3869
+ const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);
3870
+
3871
+ __m256 accum1 = _mm256_setzero_ps();
3872
+ __m256 accum2 = _mm256_setzero_ps();
3873
+ for (int i = 0; i < nb; ++i) {
3874
+
3875
+ const int8_t * q8 = y[i].qs;
3876
+ const uint8_t * qs = x[i].qs;
3877
+ const uint8_t * qh = x[i].qh;
3878
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
3879
+
3880
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
3881
+ // Extract 3-bit scales (16 values)
3882
+ __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc);
3883
+ scales = _mm256_srlv_epi64(scales, scales_shift);
3884
+ scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);
3885
+
3886
+ // Indices to repeat each scale 8 times.
3887
+ __m256i scales_idx1 = _mm256_set1_epi16(0x0100);
3888
+ __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8));
3889
+
3890
+ __m256i sumi1 = _mm256_setzero_si256();
3891
+ __m256i sumi2 = _mm256_setzero_si256();
3892
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
3893
+ #ifdef __BMI2__
3894
+ const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL)
3895
+ | _pdep_u64(*(const uint16_t*)(qh) & 0x7777, 0xf000f000f000f00ULL);
3896
+ const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL)
3897
+ | _pdep_u64(*(const uint16_t*)(qh + 2) & 0x7777, 0xf000f000f000f00ULL);
3898
+ const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
3899
+ const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
3900
+ const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
3901
+ const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
3902
+
3903
+ // Convert signs to bytes 0x81 (negative) or 0x01 (positive)
3904
+ const uint64_t delta_sign = _pdep_u64(*(const uint32_t*)(qh) & 0x88888888, 0xf0f0f0f0f0f0f0f0ULL);
3905
+ const __m256i delta1 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign)));
3906
+ const __m256i delta2 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign >> 32)));
3907
+ #else
3908
+ const __m256i q1b_1 = _mm256_set_epi64x(
3909
+ iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
3910
+ iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
3911
+ );
3912
+ const __m256i q1b_2 = _mm256_set_epi64x(
3913
+ iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
3914
+ iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
3915
+ );
3916
+
3917
+ const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3918
+ qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
3919
+ qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3920
+ qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3921
+ const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3922
+ qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
3923
+ qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3924
+ qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3925
+ #endif
3926
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3927
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3928
+
3929
+ const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
3930
+ const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
3931
+ const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1));
3932
+ const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2));
3933
+
3934
+ __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1);
3935
+ __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2);
3936
+
3937
+ scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8);
3938
+ scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8);
3939
+
3940
+ const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
3941
+ const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
3942
+ const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
3943
+ const __m256i p4 = _mm256_madd_epi16(dot4, scale2);
3944
+
3945
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
3946
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));
3947
+
3948
+ qs += 8; qh += 4;
3949
+ }
3950
+
3951
+ const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_CPU_FP16_TO_FP32(scale.f16));
3952
+
3953
+ accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
3954
+ accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
3955
+ }
3956
+
3957
+ *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
3958
+
3959
+ #elif defined __AVX__
3960
+ const __m128i mask = _mm_set1_epi16(0x7);
3961
+ const __m128i mone = _mm_set1_epi16(1);
3962
+
3963
+ __m256 accum1 = _mm256_setzero_ps();
3964
+ __m256 accum2 = _mm256_setzero_ps();
3965
+ for (int i = 0; i < nb; ++i) {
3966
+
3967
+ const int8_t * q8 = y[i].qs;
3968
+ const uint8_t * qs = x[i].qs;
3969
+ const uint8_t * qh = x[i].qh;
3970
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
3971
+
3972
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
3973
+
3974
+ __m128i sumi1_0 = _mm_setzero_si128();
3975
+ __m128i sumi1_1 = _mm_setzero_si128();
3976
+ __m128i sumi2_0 = _mm_setzero_si128();
3977
+ __m128i sumi2_1 = _mm_setzero_si128();
3978
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
3979
+ const __m128i q1b_1_0 = _mm_set_epi64x(
3980
+ iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
3981
+ const __m128i q1b_1_1 = _mm_set_epi64x(
3982
+ iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
3983
+ const __m128i q1b_2_0 = _mm_set_epi64x(
3984
+ iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
3985
+ const __m128i q1b_2_1 = _mm_set_epi64x(
3986
+ iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
3987
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3988
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3989
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3990
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3991
+
3992
+ const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
3993
+ const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
3994
+ const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
3995
+ const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
3996
+
3997
+ const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3998
+ qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3999
+ const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
4000
+ qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
4001
+ const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
4002
+ qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
4003
+ const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
4004
+ qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
4005
+
4006
+ const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
4007
+ const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
4008
+ const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
4009
+ const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);
4010
+
4011
+ __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
4012
+ __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
4013
+ __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
4014
+ __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);
4015
+
4016
+ scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
4017
+ scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
4018
+ scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
4019
+ scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
4020
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
4021
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
4022
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
4023
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
4024
+ const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
4025
+ const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
4026
+ const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
4027
+ const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);
4028
+
4029
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
4030
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
4031
+ sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
4032
+ sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));
4033
+
4034
+ qs += 8; qh += 4;
4035
+ }
4036
+
4037
+ const __m256 d = _mm256_set1_ps(y[i].d * LM_GGML_CPU_FP16_TO_FP32(scale.f16));
4038
+
4039
+ accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
4040
+ accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
4041
+ }
4042
+
4043
+ *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
4044
+
4045
+ #else
4046
+
4047
+ int sum1[2], sum2[2], delta[4];
4048
+
4049
+ float sumf = 0;
4050
+ for (int i = 0; i < nb; i++) {
4051
+
4052
+ const int8_t * q8 = y[i].qs;
4053
+ const uint8_t * qs = x[i].qs;
4054
+ const uint8_t * qh = x[i].qh;
4055
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
4056
+
4057
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
4058
+
4059
+ int sumi1 = 0, sumi2 = 0;
4060
+ for (int ib = 0; ib < QK_K/32; ++ib) {
4061
+ delta[0] = qh[0] & 0x08 ? -1 : 1;
4062
+ delta[1] = qh[0] & 0x80 ? -1 : 1;
4063
+ delta[2] = qh[1] & 0x08 ? -1 : 1;
4064
+ delta[3] = qh[1] & 0x80 ? -1 : 1;
4065
+ sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
4066
+ for (int l = 0; l < 4; ++l) {
4067
+ const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
4068
+ int lsum1 = 0, lsum2 = 0;
4069
+ for (int j = 0; j < 8; ++j) {
4070
+ lsum1 += q8[j] * grid[j];
4071
+ lsum2 += q8[j];
4072
+ }
4073
+ q8 += 8;
4074
+ sum1[l/2] += lsum1;
4075
+ sum2[l/2] += lsum2*delta[l];
4076
+ }
4077
+
4078
+ const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
4079
+ const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
4080
+
4081
+ sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
4082
+ sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
4083
+ qs += 4;
4084
+ qh += 2;
4085
+ }
4086
+
4087
+ sumf += LM_GGML_CPU_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
4088
+ }
4089
+
4090
+ *s = sumf;
4091
+
4092
+ #endif
4093
+ }
4094
+
4095
+ void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
4096
+ assert(nrc == 1);
4097
+ UNUSED(nrc);
4098
+ UNUSED(bx);
4099
+ UNUSED(by);
4100
+ UNUSED(bs);
4101
+ assert(n % QK4_NL == 0);
4102
+ static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
4103
+
4104
+ const block_iq4_nl * LM_GGML_RESTRICT x = vx;
4105
+ const block_q8_0 * LM_GGML_RESTRICT y = vy;
4106
+
4107
+ const int nb = n / QK4_NL;
4108
+
4109
+ int ib = 0;
4110
+ float sumf = 0;
4111
+
4112
+ #if defined __AVX2__
4113
+
4114
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
4115
+ const __m128i m4b = _mm_set1_epi8(0x0f);
4116
+ const __m256i mone = _mm256_set1_epi16(1);
4117
+
4118
+ __m256 accum1 = _mm256_setzero_ps();
4119
+ __m256 accum2 = _mm256_setzero_ps();
4120
+ for (; ib + 1 < nb; ib += 2) {
4121
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
4122
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
4123
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
4124
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
4125
+ const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
4126
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
4127
+ const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
4128
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
4129
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
4130
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
4131
+ const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
4132
+ const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
4133
+ accum1 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*LM_GGML_CPU_FP16_TO_FP32(x[ib + 0].d)),
4134
+ _mm256_cvtepi32_ps(p_1), accum1);
4135
+ accum2 = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*LM_GGML_CPU_FP16_TO_FP32(x[ib + 1].d)),
4136
+ _mm256_cvtepi32_ps(p_2), accum2);
4137
+ }
4138
+
4139
+ sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
4140
+
4141
+ #elif defined __AVX__
4142
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
4143
+ const __m128i m4b = _mm_set1_epi8(0x0f);
4144
+
4145
+ __m256 accum = _mm256_setzero_ps();
4146
+ for (; ib + 1 < nb; ib += 2) {
4147
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
4148
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
4149
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
4150
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
4151
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
4152
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
4153
+
4154
+ const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
4155
+ const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
4156
+ const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
4157
+ const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
4158
+
4159
+ const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
4160
+ const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
4161
+ accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
4162
+ }
4163
+
4164
+ sumf = hsum_float_8(accum);
4165
+
4166
+ #endif
4167
+ for (; ib < nb; ++ib) {
4168
+ const float d = LM_GGML_CPU_FP16_TO_FP32(y[ib].d)*LM_GGML_CPU_FP16_TO_FP32(x[ib].d);
4169
+ int sumi1 = 0, sumi2 = 0;
4170
+ for (int j = 0; j < QK4_NL/2; ++j) {
4171
+ sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
4172
+ sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
4173
+ }
4174
+ sumf += d * (sumi1 + sumi2);
4175
+ }
4176
+ *s = sumf;
4177
+ }
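The AVX paths of lm_ggml_vec_dot_iq4_nl_q8_0 use _mm_shuffle_epi8 as a 16-entry lookup table: each 4-bit quant selects one signed value from kvalues_iq4nl, exactly as in the scalar tail loop. A minimal per-byte sketch (the table is passed in so the snippet stays self-contained; helper name illustrative):

#include <stdint.h>

// Decode the two 4-bit quants stored in one IQ4_NL byte.
static inline void iq4nl_decode_byte_sketch(uint8_t packed, const int8_t values[16], int8_t out[2]) {
    out[0] = values[packed & 0x0f];   // weight j             (first half of the block)
    out[1] = values[packed >> 4];     // weight j + QK4_NL/2  (second half of the block)
}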
4178
+
4179
+ void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
4180
+ assert(nrc == 1);
4181
+ UNUSED(nrc);
4182
+ UNUSED(bx);
4183
+ UNUSED(by);
4184
+ UNUSED(bs);
4185
+ assert(n % QK_K == 0);
4186
+
4187
+ const block_iq4_xs * LM_GGML_RESTRICT x = vx;
4188
+ const block_q8_K * LM_GGML_RESTRICT y = vy;
4189
+
4190
+ const int nb = n / QK_K;
4191
+
4192
+ #if defined __AVX2__
4193
+
4194
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
4195
+ const __m128i m4b = _mm_set1_epi8(0x0f);
4196
+
4197
+ __m256 accum = _mm256_setzero_ps();
4198
+ for (int ibl = 0; ibl < nb; ++ibl) {
4199
+ const uint8_t * qs = x[ibl].qs;
4200
+ const int8_t * q8 = y[ibl].qs;
4201
+ uint16_t sh = x[ibl].scales_h;
4202
+ __m256i sumi1 = _mm256_setzero_si256();
4203
+ __m256i sumi2 = _mm256_setzero_si256();
4204
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
4205
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
4206
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
4207
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
4208
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
4209
+ const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
4210
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
4211
+ const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
4212
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
4213
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
4214
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
4215
+ const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
4216
+ const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
4217
+ sh >>= 4;
4218
+ const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
4219
+ const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
4220
+ sumi1 = _mm256_add_epi32(p_1, sumi1);
4221
+ sumi2 = _mm256_add_epi32(p_2, sumi2);
4222
+ }
4223
+ accum = _mm256_fmadd_ps(_mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
4224
+ _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
4225
+ }
4226
+
4227
+ *s = hsum_float_8(accum);
4228
+
4229
+ #elif defined __AVX__
4230
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
4231
+ const __m128i m4b = _mm_set1_epi8(0x0f);
4232
+
4233
+ __m256 accum = _mm256_setzero_ps();
4234
+ for (int ibl = 0; ibl < nb; ++ibl) {
4235
+ const uint8_t * qs = x[ibl].qs;
4236
+ const int8_t * q8 = y[ibl].qs;
4237
+ uint16_t sh = x[ibl].scales_h;
4238
+ __m128i sumi1_0 = _mm_setzero_si128();
4239
+ __m128i sumi1_1 = _mm_setzero_si128();
4240
+ __m128i sumi2_0 = _mm_setzero_si128();
4241
+ __m128i sumi2_1 = _mm_setzero_si128();
4242
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
4243
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
4244
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
4245
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
4246
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
4247
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
4248
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
4249
+ const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
4250
+ const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
4251
+ const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
4252
+ const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
4253
+ const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
4254
+ const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
4255
+ const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
4256
+ const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
4257
+ const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
4258
+ const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
4259
+ sh >>= 4;
4260
+ const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
4261
+ const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
4262
+ const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
4263
+ const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
4264
+ sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
4265
+ sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
4266
+ sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
4267
+ sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
4268
+ }
4269
+ __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
4270
+ __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
4271
+ accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(LM_GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
4272
+ _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
4273
+ }
4274
+
4275
+ *s = hsum_float_8(accum);
4276
+
4277
+ #else
4278
+ float sumf = 0;
4279
+ for (int ibl = 0; ibl < nb; ++ibl) {
4280
+ const float d4d8 = LM_GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
4281
+ uint16_t h = x[ibl].scales_h;
4282
+ const uint8_t * qs = x[ibl].qs;
4283
+ const int8_t * q8 = y[ibl].qs;
4284
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
4285
+ const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
4286
+ const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30);
4287
+ h >>= 4;
4288
+ const float d1 = d4d8*(ls1 - 32);
4289
+ const float d2 = d4d8*(ls2 - 32);
4290
+ int sumi1 = 0, sumi2 = 0;
4291
+ for (int j = 0; j < 16; ++j) {
4292
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
4293
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
4294
+ }
4295
+ sumf += d1 * (sumi1 + sumi2);
4296
+ qs += 16;
4297
+ q8 += 32;
4298
+ sumi1 = sumi2 = 0;
4299
+ for (int j = 0; j < 16; ++j) {
4300
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
4301
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
4302
+ }
4303
+ sumf += d2 * (sumi1 + sumi2);
4304
+ qs += 16;
4305
+ q8 += 32;
4306
+ }
4307
+ }
4308
+ *s = sumf;
4309
+ #endif
4310
+ }
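IQ4_XS splits each 6-bit block scale across two fields: the low four bits come from scales_l and the top two bits from the rolling scales_h word, after which 32 is subtracted to re-center the scale. A short sketch of the reconstruction shared by all three paths above (helper name illustrative):

#include <stdint.h>

// Reconstruct the two signed 6-bit scales for one pair of 32-weight blocks.
static inline void iq4xs_scales_sketch(uint8_t scales_l_byte, uint16_t sh, int *ls1, int *ls2) {
    *ls1 = (int)((scales_l_byte & 0x0f) | ((sh << 4) & 0x30)) - 32;  // uses bits 0-1 of sh
    *ls2 = (int)((scales_l_byte >> 4)   | ((sh << 2) & 0x30)) - 32;  // uses bits 2-3 of sh
}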
4311
+