whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };
 
+static inline int best_index_int8(int n, constant float * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -97,6 +108,176 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
     }
 }
 
+void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK4_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d  = max / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = src[0 + j]*id;
+        const float x1 = src[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+        dst.qs[j] = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
+#pragma METAL fp math_mode(safe)
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; j++) {
+        const float v = src[j];
+        if (min > v) min = v;
+        if (max < v) max = v;
+    }
+
+    const float d  = (max - min) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (src[0 + j] - min)*id;
+        const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+        dst.qs[j] = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK5_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d  = max / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = src[0 + j]*id;
+        const float x1 = src[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
+#pragma METAL fp math_mode(safe)
+    float max = src[0];
+    float min = src[0];
+
+    for (int j = 1; j < QK5_1; j++) {
+        const float v = src[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d  = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (src[0 + j] - min)*id;
+        const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
+#pragma METAL fp math_mode(safe)
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK4_NL; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d  = max / kvalues_iq4nl_f[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = src[0 + j]*id;
+        const float x1 = src[QK4_NL/2 + j]*id;
+
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+        dst.qs[j] = xi0 | (xi1 << 4);
+
+        const float v0 = kvalues_iq4nl_f[xi0];
+        const float v1 = kvalues_iq4nl_f[xi1];
+        const float w0 = src[0 + j]*src[0 + j];
+        const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+        sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+
+    }
+
+    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
@@ -279,6 +460,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
     }
 }
 
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = src[j];
+        amax = MAX(amax, fabs(v));
+    }
+
+    const float d  = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = src[j]*id;
+
+        dst.qs[j] = round(x0);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const float d = xb->d;
@@ -993,31 +1194,125 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+kernel void kernel_reglu(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        dst_row[i0] = x0*x1*(x0 > 0.0f);
+    }
+}
+
+kernel void kernel_geglu(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+        dst_row[i0] = gelu*x1;
+    }
+}
+
+kernel void kernel_swiglu(
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float silu = x0 / (1.0f + exp(-x0));
+
+        dst_row[i0] = silu*x1;
+    }
+}
+
+template <bool norm>
 kernel void kernel_sum_rows(
+        constant ggml_metal_kargs_sum_rows & args,
         device const float * src0,
         device float * dst,
-        constant ggml_metal_kargs_sum_rows & args,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    int64_t i3 = tgpig.z;
+    int64_t i2 = tgpig.y;
+    int64_t i1 = tgpig.x;
 
     if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }
 
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
     device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
     device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
 
-    float row_sum = 0;
+    float sumf = 0;
 
-    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
-        row_sum += src_row[i0];
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        sumf += src_row[i0];
     }
 
-    dst_row[0] = row_sum;
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    if (tpitg.x == 0) {
+        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+    }
 }
 
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
 template<typename T>
 kernel void kernel_soft_max(
         device const char * src0,
@@ -2502,6 +2797,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
 template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
 #endif
 
+template<typename T04, typename T14, typename args_t>
+void kernel_mul_mv_c4_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig,
+        ushort tiisg) {
+    const int r0 = tgpig.x*32 + tiisg;
+    const int rb = tgpig.y*N_MV_T_T;
+    const int im = tgpig.z;
+
+    if (r0 >= args.ne01) {
+        return;
+    }
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+    device const T04 * x = (device const T04 *) (src0 + offset0);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+    for (int row = 0; row < N_MV_T_T; ++row) {
+        int r1 = rb + row;
+        if (r1 >= args.ne11) {
+            break;
+        }
+
+        const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+        device const T14 * y = (device const T14 *) (src1 + offset1);
+
+        dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
+    }
+}
+
+template<typename T04, typename T14>
+kernel void kernel_mul_mv_c4(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
+        args,
+        src0,
+        src1,
+        dst,
+        tgpig,
+        tiisg);
+}
+
+typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
+
+template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
+template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
+#endif
+
 template<typename T, typename T4>
 kernel void kernel_mul_mv_1row(
         constant ggml_metal_kargs_mul_mv & args,
@@ -3328,14 +3687,12 @@ kernel void kernel_flash_attn_ext(
     constexpr short NW = N_SIMDWIDTH;
     constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
-    const short TS = nsg*SH; // shared memory size per query in (s_t == float)
-    const short T = DK + 2*TS; // shared memory size per query in (half)
+    const short TS = nsg*SH; // shared memory size per query in (s_t == float)
+    const short T = 2*DK + 2*TS; // shared memory size per query in (half)
 
-    threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
-    threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
-    threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+    threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3711,7 @@ kernel void kernel_flash_attn_ext(
             if (iq1 + j < args.ne01) {
                 sq4[j*DK4 + i] = (q4_t) q4[i];
             } else {
-                sq4[j*DK4 + i] = (q4_t) 0.0f;
+                sq4[j*DK4 + i] = 0;
             }
         }
     }
@@ -3548,20 +3905,20 @@ kernel void kernel_flash_attn_ext(
 
             // O = diag(ms)*O
             {
-                s8x8_t mm;
-                simdgroup_load(mm, ss + 2*C, TS, 0, false);
+                s8x8_t ms;
+                simdgroup_load(ms, ss + 2*C, TS, 0, false);
 
                 #pragma unroll(DV8)
                 for (short i = 0; i < DV8; ++i) {
-                    simdgroup_multiply(lo[i], mm, lo[i]);
+                    simdgroup_multiply(lo[i], ms, lo[i]);
                 }
             }
 
            // O = O + (Q*K^T)*V
            {
                for (short cc = 0; cc < C/8; ++cc) {
-                    s8x8_t ms;
-                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);
+                    s8x8_t vs;
+                    simdgroup_load(vs, ss + 8*cc, TS, 0, false);
 
                     if (is_same<vd4x4_t, v4x4_t>::value) {
                         // we can read directly from global memory
@@ -3572,7 +3929,7 @@ kernel void kernel_flash_attn_ext(
                             v8x8_t mv;
                             simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
 
-                            simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+                            simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
                         }
                     } else {
                         for (short ii = 0; ii < DV16; ii += 4) {
@@ -3593,10 +3950,10 @@ kernel void kernel_flash_attn_ext(
                                     v8x8_t mv;
 
                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
 
                                     simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
                                 }
                             } else {
                                 if (ii + tx < DV16) {
@@ -3611,10 +3968,10 @@ kernel void kernel_flash_attn_ext(
                                         v8x8_t mv;
 
                                         simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                        simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+                                        simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
 
                                         simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                        simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                        simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
                                     }
                                 }
                             }
@@ -3624,93 +3981,89 @@ kernel void kernel_flash_attn_ext(
         }
 
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (short j = 0; j < Q; ++j) {
-            if (tiisg == 0) {
-                ss[j*TS + 0] = S[j];
-                ss[j*TS + 1] = M[j];
-            }
+        for (short j = tiisg; j < Q; j += NW) {
+            ss[j*TS + 0] = S[j];
+            ss[j*TS + 1] = M[j];
         }
     }
 
-    // reduce the warps sequentially
-    for (ushort sg = 1; sg < nsg; ++sg) {
-        float S = { 0.0f };
-        float M = { -__FLT_MAX__/2 };
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
+    threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
 
-        // each simdgroup stores its output to shared memory, reusing sq
-        if (sgitg == sg) {
-            for (short i = 0; i < DV8; ++i) {
-                simdgroup_store(lo[i], so + i*8, DV, 0, false);
-            }
+    // store result to shared memory in F32
+    if (sgitg == 0) {
+        for (short i = 0; i < DV8; ++i) {
+            //simdgroup_store(lo[i], so + i*8, DV, 0, false);
+            simdgroup_float8x8 t(1.0f);
+            simdgroup_multiply(t, lo[i], t);
+            simdgroup_store(t, so + i*8, DV, 0, false);
         }
+    }
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        // the first simdgroup accumulates the results from the other simdgroups
-        if (sgitg == 0) {
-            for (short j = 0; j < Q; ++j) {
-                const float S0 = ss[j*TS + 0];
-                const float S1 = ss[j*TS + sg*SH + 0];
+    // reduce the warps sequentially
+    for (ushort sg = 1; sg < nsg; ++sg) {
+        if (sgitg == sg) {
+            for (short j = tiisg; j < Q; j += NW) {
+                const float S0 = ss[j*TS - 1*SH + 0];
+                const float S1 = ss[j*TS + 0];
 
-                const float M0 = ss[j*TS + 1];
-                const float M1 = ss[j*TS + sg*SH + 1];
+                const float M0 = ss[j*TS - 1*SH + 1];
+                const float M1 = ss[j*TS + 1];
 
-                M = max(M0, M1);
+                const float M = max(M0, M1);
 
-                const float ms0 = exp(M0 - M);
-                const float ms1 = exp(M1 - M);
+                float ms0 = exp(M0 - M);
+                float ms1 = exp(M1 - M);
 
-                S = S0*ms0 + S1*ms1;
+                const float S = S0*ms0 + S1*ms1;
 
-                if (tiisg == 0) {
-                    ss[j*TS + 0] = S;
-                    ss[j*TS + 1] = M;
+                ss[j*TS + 0] = S;
+                ss[j*TS + 1] = M;
 
-                    ss[j*TS + 2*C + j ] = ms0;
-                    ss[j*TS + 2*C + j + sg*SH] = ms1;
-                }
+                ss[j*TS + 2*C + j - 1*SH] = ms0;
+                ss[j*TS + 2*C + j ] = ms1;
             }
 
+            //simdgroup_barrier(mem_flags::mem_threadgroup);
+
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
             {
                 s8x8_t ms0;
                 s8x8_t ms1;
 
-                simdgroup_load(ms0, ss + 2*C, TS, 0, false);
-                simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
+                simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
+                simdgroup_load(ms1, ss + 2*C, TS, 0, false);
 
                 #pragma unroll(DV8)
                 for (short i = 0; i < DV8; ++i) {
-                    o8x8_t t;
+                    simdgroup_float8x8 t;
 
                     simdgroup_load (t, so + i*8, DV, 0, false);
-                    simdgroup_multiply(t, ms1, t);
+                    simdgroup_multiply(t, ms0, t);
 
-                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                    simdgroup_multiply_accumulate(t, ms1, lo[i], t);
+                    simdgroup_store(t, so + i*8, DV, 0, false);
                 }
             }
         }
-    }
 
-    // store result to shared memory (reuse sq)
-    if (sgitg == 0) {
-        for (short i = 0; i < DV8; ++i) {
-            simdgroup_store(lo[i], so + i*8, DV, 0, false);
-        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    device float4 * dst4 = (device float4 *) dst;
+    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
 
     // final rescale with 1/S and store to global memory
-    if (sgitg == 0) {
-        for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
-            const float S = ss[j*TS + 0];
+    for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
+        const float S = 1.0f/sf[j*TS + 0];
 
-            for (short i = tiisg; i < DV4; i += NW) {
-                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
-            }
+        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+
+        for (short i = tiisg; i < DV4; i += NW) {
+            dst4[i] = (float4) so4[j*DV4 + i]*S;
         }
     }
 }
@@ -3719,12 +4072,22 @@ kernel void kernel_flash_attn_ext(
 // template to be able to explore different combinations
 //
 #define FA_TYPES \
-    half, half4, simdgroup_half8x8, \
-    half, half4x4, simdgroup_half8x8, \
-    half, half4x4, simdgroup_half8x8, \
-    float, simdgroup_float8x8, \
-    float, simdgroup_float8x8, \
-    half, half4, simdgroup_half8x8
+    float, float4, simdgroup_float8x8, \
+    half, half4x4, simdgroup_half8x8, \
+    half, half4x4, simdgroup_half8x8, \
+    float, simdgroup_float8x8, \
+    float, simdgroup_float8x8, \
+    half, half4, simdgroup_half8x8
+    //float, float4, simdgroup_float8x8
+
+#define FA_TYPES_BF \
+    bfloat, bfloat4, simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    float, simdgroup_float8x8, \
+    float, simdgroup_float8x8, \
+    half, half4, simdgroup_half8x8
+    //float, float4, simdgroup_float8x8
 
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
 
@@ -3739,15 +4102,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3801,6 +4164,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
 
 #undef FA_TYPES
+#undef FA_TYPES_BF
 
 template<
     typename q4_t, // query types in shared memory
@@ -3847,12 +4211,12 @@ kernel void kernel_flash_attn_ext_vec(
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-    //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
-    threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
-    threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
-    threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
-    threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
+    //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+    threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
+    threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
+    threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
@@ -4157,7 +4521,7 @@ kernel void kernel_flash_attn_ext_vec(
     half4, \
     float, \
     float, float4, \
-    half4
+    float4
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
@@ -4271,11 +4635,16 @@ kernel void kernel_cpy(
         device const char * src0,
         device char * dst,
         uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 ntg[[threads_per_threadgroup]]) {
+        ushort3 tptg[[threads_per_threadgroup]]) {
     const int i03 = tgpig[2];
     const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
+    const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
+
+    if (i01 >= args.ne01) {
+        return;
+    }
 
     const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
@@ -4286,7 +4655,7 @@ kernel void kernel_cpy(
 
     device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
+    for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
         device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
         dst_data[i00] = (T1) src[0];
     }
@@ -4306,6 +4675,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
 template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
 #endif
 
+// TODO: templetify these kernels
 kernel void kernel_cpy_f32_q8_0(
         constant ggml_metal_kargs_cpy & args,
         device const char * src0,
@@ -4329,23 +4699,7 @@ kernel void kernel_cpy_f32_q8_0(
     for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-
-        for (int j = 0; j < QK8_0; j++) {
-            const float v = src[j];
-            amax = MAX(amax, fabs(v));
-        }
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK8_0].d = d;
-
-        for (int j = 0; j < QK8_0; ++j) {
-            const float x0 = src[j]*id;
-
-            dst_data[i00/QK8_0].qs[j] = round(x0);
-        }
+        quantize_q8_0(src, dst_data[i00/QK8_0]);
     }
 }
 
@@ -4372,32 +4726,7 @@ kernel void kernel_cpy_f32_q4_0(
     for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max = 0.0f;
-
-        for (int j = 0; j < QK4_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max = v;
-            }
-        }
-
-        const float d = max / -8;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_0].d = d;
-
-        for (int j = 0; j < QK4_0/2; ++j) {
-            const float x0 = src[0 + j]*id;
-            const float x1 = src[QK4_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
-            dst_data[i00/QK4_0].qs[j] = xi0;
-            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_0(src, dst_data[i00/QK4_0]);
     }
 }
 
@@ -4424,31 +4753,7 @@ kernel void kernel_cpy_f32_q4_1(
     for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-
-        for (int j = 0; j < QK4_1; j++) {
-            const float v = src[j];
-            if (min > v) min = v;
-            if (max < v) max = v;
-        }
-
-        const float d = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_1].d = d;
-        dst_data[i00/QK4_1].m = min;
-
-        for (int j = 0; j < QK4_1/2; ++j) {
-            const float x0 = (src[0 + j] - min)*id;
-            const float x1 = (src[QK4_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
-            dst_data[i00/QK4_1].qs[j] = xi0;
-            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_1(src, dst_data[i00/QK4_1]);
     }
 }
 
4475
4780
  for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
4476
4781
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4477
4782
 
4478
- float amax = 0.0f; // absolute max
4479
- float max = 0.0f;
4480
-
4481
- for (int j = 0; j < QK5_0; j++) {
4482
- const float v = src[j];
4483
- if (amax < fabs(v)) {
4484
- amax = fabs(v);
4485
- max = v;
4486
- }
4487
- }
4488
-
4489
- const float d = max / -16;
4490
- const float id = d ? 1.0f/d : 0.0f;
4491
-
4492
- dst_data[i00/QK5_0].d = d;
4493
-
4494
- uint32_t qh = 0;
4495
- for (int j = 0; j < QK5_0/2; ++j) {
4496
- const float x0 = src[0 + j]*id;
4497
- const float x1 = src[QK5_0/2 + j]*id;
4498
-
4499
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
4500
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
4501
-
4502
- dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4503
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4504
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
4505
- }
4506
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4507
- for (int j = 0; j < 4; ++j) {
4508
- dst_data[i00/QK5_0].qh[j] = qh8[j];
4509
- }
4783
+ quantize_q5_0(src, dst_data[i00/QK5_0]);
4510
4784
  }
4511
4785
  }
4512
4786
 
@@ -4533,49 +4807,8 @@ kernel void kernel_cpy_f32_q5_1(
4533
4807
  for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
4534
4808
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4535
4809
 
4536
- float max = src[0];
4537
- float min = src[0];
4538
-
4539
- for (int j = 1; j < QK5_1; j++) {
4540
- const float v = src[j];
4541
- min = v < min ? v : min;
4542
- max = v > max ? v : max;
4543
- }
4544
-
4545
- const float d = (max - min) / 31;
4546
- const float id = d ? 1.0f/d : 0.0f;
4547
-
4548
- dst_data[i00/QK5_1].d = d;
4549
- dst_data[i00/QK5_1].m = min;
4550
-
4551
- uint32_t qh = 0;
4552
- for (int j = 0; j < QK5_1/2; ++j) {
4553
- const float x0 = (src[0 + j] - min)*id;
4554
- const float x1 = (src[QK5_1/2 + j] - min)*id;
4555
-
4556
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
4557
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
4558
-
4559
- dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4560
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4561
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
4562
- }
4563
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4564
- for (int j = 0; j < 4; ++j) {
4565
- dst_data[i00/QK5_1].qh[j] = qh8[j];
4566
- }
4567
- }
4568
- }
4569
-
4570
- static inline int best_index_int8(int n, constant float * val, float x) {
4571
- if (x <= val[0]) return 0;
4572
- if (x >= val[n-1]) return n-1;
4573
- int ml = 0, mu = n-1;
4574
- while (mu-ml > 1) {
4575
- int mav = (ml+mu)/2;
4576
- if (x < val[mav]) mu = mav; else ml = mav;
4810
+ quantize_q5_1(src, dst_data[i00/QK5_1]);
4577
4811
  }
4578
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
4579
4812
  }
4580
4813
 
4581
4814
  kernel void kernel_cpy_f32_iq4_nl(
@@ -4601,40 +4834,7 @@ kernel void kernel_cpy_f32_iq4_nl(
     for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max = 0.0f;
-
-        for (int j = 0; j < QK4_NL; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max = v;
-            }
-        }
-
-        const float d = max / kvalues_iq4nl_f[0];
-        const float id = d ? 1.0f/d : 0.0f;
-
-        float sumqx = 0, sumq2 = 0;
-        for (int j = 0; j < QK4_NL/2; ++j) {
-            const float x0 = src[0 + j]*id;
-            const float x1 = src[QK4_NL/2 + j]*id;
-
-            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
-            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
-
-            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
-
-            const float v0 = kvalues_iq4nl_f[xi0];
-            const float v1 = kvalues_iq4nl_f[xi1];
-            const float w0 = src[0 + j]*src[0 + j];
-            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
-            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
-            sumq2 += w0*v0*v0 + w1*v1*v1;
-
-        }
-
-        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+        quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
     }
 }
 
@@ -6315,10 +6515,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
+        constant ggml_metal_kargs_get_rows & args,
         device const void * src0,
         device const void * src1,
         device float * dst,
-        constant ggml_metal_kargs_get_rows & args,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
@@ -6338,10 +6538,10 @@ kernel void kernel_get_rows_q(
 
 template<typename T>
 kernel void kernel_get_rows_f(
+        constant ggml_metal_kargs_get_rows & args,
         device const void * src0,
         device const void * src1,
         device float * dst,
-        constant ggml_metal_kargs_get_rows & args,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
@@ -6359,10 +6559,10 @@ kernel void kernel_get_rows_f(
 }
 
 kernel void kernel_get_rows_i32(
+        constant ggml_metal_kargs_get_rows & args,
         device const void * src0,
         device const void * src1,
         device int32_t * dst,
-        constant ggml_metal_kargs_get_rows & args,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
@@ -6379,6 +6579,67 @@ kernel void kernel_get_rows_i32(
     }
 }
 
+template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
+kernel void kernel_set_rows_q32(
+        constant ggml_metal_kargs_set_rows & args,
+        device const void * src0,
+        device const void * src1,
+        device float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+    device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        quantize_func(src_row + 32*ind, dst_row[ind]);
+    }
+}
+
+template<typename T>
+kernel void kernel_set_rows_f(
+        constant ggml_metal_kargs_set_rows & args,
+        device const void * src0,
+        device const void * src1,
+        device float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint3 tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+    device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        dst_row[ind] = (T) src_row[ind];
+    }
+}
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6802,6 +7063,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
 template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
 template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
 
+//
+// set rows
+//
+
+typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
+
+template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
+template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
+#endif
+
+typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
+
+template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
+template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
+template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
+template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
+
 //
 // matrix-matrix multiplication
 //
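
For orientation only (not part of the published diff): the factored-out `quantize_q8_0` Metal helper above, which the rewritten `kernel_cpy_f32_q8_0` and the new `kernel_set_rows_q8_0` both reuse, follows the Q8_0 block scheme visible in the shader code: each block of 32 floats is scaled by d = amax/127 and stored as rounded int8 values, so a value is recovered as qs[j]*d. The minimal host-side C++ sketch below mirrors that round trip under those assumptions; `BlockQ8_0` is an illustrative stand-in and not ggml's actual `block_q8_0` (which stores `d` as a half).

```cpp
// Sketch of the Q8_0 quantize/dequantize round trip mirrored from the Metal helper above.
// Assumption: QK8_0 = 32 and the scale is amax/127, as in the shader code shown in the diff.
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr int QK8_0 = 32;           // block size used by the kernels above

struct BlockQ8_0 {                  // illustrative stand-in for ggml's block_q8_0
    float  d;                       // per-block scale (ggml stores this as a half)
    int8_t qs[QK8_0];               // quantized values
};

static void quantize_q8_0_ref(const float * src, BlockQ8_0 & dst) {
    float amax = 0.0f;              // absolute max of the block
    for (int j = 0; j < QK8_0; ++j) {
        amax = std::fmax(amax, std::fabs(src[j]));
    }

    const float d  = amax / 127.0f; // same scale as (1 << 7) - 1 in the shader
    const float id = d ? 1.0f/d : 0.0f;

    dst.d = d;
    for (int j = 0; j < QK8_0; ++j) {
        dst.qs[j] = (int8_t) std::lround(src[j]*id);
    }
}

int main() {
    float x[QK8_0];
    for (int j = 0; j < QK8_0; ++j) x[j] = 0.1f*j - 1.5f;

    BlockQ8_0 b;
    quantize_q8_0_ref(x, b);

    // round trip: reconstruct qs[j]*d and report the worst absolute error (~d/2 at most)
    float max_err = 0.0f;
    for (int j = 0; j < QK8_0; ++j) {
        max_err = std::fmax(max_err, std::fabs(x[j] - b.qs[j]*b.d));
    }
    std::printf("scale d = %f, max round-trip error = %f\n", b.d, max_err);
    return 0;
}
```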