whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -26,6 +26,7 @@ enum llm_arch {
26
26
  LLM_ARCH_NOMIC_BERT_MOE,
27
27
  LLM_ARCH_NEO_BERT,
28
28
  LLM_ARCH_JINA_BERT_V2,
29
+ LLM_ARCH_JINA_BERT_V3,
29
30
  LLM_ARCH_BLOOM,
30
31
  LLM_ARCH_STABLELM,
31
32
  LLM_ARCH_QWEN,
@@ -38,6 +39,7 @@ enum llm_arch {
38
39
  LLM_ARCH_PHI3,
39
40
  LLM_ARCH_PHIMOE,
40
41
  LLM_ARCH_PLAMO,
42
+ LLM_ARCH_PLAMO2,
41
43
  LLM_ARCH_CODESHELL,
42
44
  LLM_ARCH_ORION,
43
45
  LLM_ARCH_INTERNLM2,
@@ -47,8 +49,12 @@ enum llm_arch {
47
49
  LLM_ARCH_GEMMA2,
48
50
  LLM_ARCH_GEMMA3,
49
51
  LLM_ARCH_GEMMA3N,
52
+ LLM_ARCH_GEMMA_EMBEDDING,
50
53
  LLM_ARCH_STARCODER2,
51
54
  LLM_ARCH_MAMBA,
55
+ LLM_ARCH_MAMBA2,
56
+ LLM_ARCH_JAMBA,
57
+ LLM_ARCH_FALCON_H1,
52
58
  LLM_ARCH_XVERSE,
53
59
  LLM_ARCH_COMMAND_R,
54
60
  LLM_ARCH_COHERE2,
@@ -62,18 +68,22 @@ enum llm_arch {
62
68
  LLM_ARCH_DEEPSEEK2,
63
69
  LLM_ARCH_CHATGLM,
64
70
  LLM_ARCH_GLM4,
71
+ LLM_ARCH_GLM4_MOE,
65
72
  LLM_ARCH_BITNET,
66
73
  LLM_ARCH_T5,
67
74
  LLM_ARCH_T5ENCODER,
68
75
  LLM_ARCH_JAIS,
69
76
  LLM_ARCH_NEMOTRON,
77
+ LLM_ARCH_NEMOTRON_H,
70
78
  LLM_ARCH_EXAONE,
79
+ LLM_ARCH_EXAONE4,
71
80
  LLM_ARCH_RWKV6,
72
81
  LLM_ARCH_RWKV6QWEN2,
73
82
  LLM_ARCH_RWKV7,
74
83
  LLM_ARCH_ARWKV7,
75
84
  LLM_ARCH_GRANITE,
76
85
  LLM_ARCH_GRANITE_MOE,
86
+ LLM_ARCH_GRANITE_HYBRID,
77
87
  LLM_ARCH_CHAMELEON,
78
88
  LLM_ARCH_WAVTOKENIZER_DEC,
79
89
  LLM_ARCH_PLM,
@@ -81,6 +91,18 @@ enum llm_arch {
81
91
  LLM_ARCH_DOTS1,
82
92
  LLM_ARCH_ARCEE,
83
93
  LLM_ARCH_ERNIE4_5,
94
+ LLM_ARCH_ERNIE4_5_MOE,
95
+ LLM_ARCH_HUNYUAN_MOE,
96
+ LLM_ARCH_HUNYUAN_DENSE,
97
+ LLM_ARCH_SMOLLM3,
98
+ LLM_ARCH_OPENAI_MOE,
99
+ LLM_ARCH_LFM2,
100
+ LLM_ARCH_DREAM,
101
+ LLM_ARCH_SMALLTHINKER,
102
+ LLM_ARCH_LLADA,
103
+ LLM_ARCH_LLADA_MOE,
104
+ LLM_ARCH_SEED_OSS,
105
+ LLM_ARCH_GROVEMOE,
84
106
  LLM_ARCH_UNKNOWN,
85
107
  };
86
108
 
@@ -108,6 +130,7 @@ enum llm_kv {
108
130
  LLM_KV_FEED_FORWARD_LENGTH,
109
131
  LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
110
132
  LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
133
+ LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,
111
134
  LLM_KV_USE_PARALLEL_RESIDUAL,
112
135
  LLM_KV_TENSOR_DATA_LAYOUT,
113
136
  LLM_KV_EXPERT_COUNT,
@@ -116,11 +139,16 @@ enum llm_kv {
116
139
  LLM_KV_EXPERT_WEIGHTS_SCALE,
117
140
  LLM_KV_EXPERT_WEIGHTS_NORM,
118
141
  LLM_KV_EXPERT_GATING_FUNC,
142
+ LLM_KV_EXPERT_GROUP_SCALE,
143
+ LLM_KV_EXPERTS_PER_GROUP,
119
144
  LLM_KV_MOE_EVERY_N_LAYERS,
145
+ LLM_KV_NEXTN_PREDICT_LAYERS,
120
146
  LLM_KV_POOLING_TYPE,
121
147
  LLM_KV_LOGIT_SCALE,
122
148
  LLM_KV_DECODER_START_TOKEN_ID,
149
+ LLM_KV_DECODER_BLOCK_COUNT,
123
150
  LLM_KV_ATTN_LOGIT_SOFTCAPPING,
151
+ LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
124
152
  LLM_KV_FINAL_LOGIT_SOFTCAPPING,
125
153
  LLM_KV_SWIN_NORM,
126
154
  LLM_KV_RESCALE_EVERY_N_LAYERS,
@@ -151,9 +179,10 @@ enum llm_kv {
151
179
  LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
152
180
  LLM_KV_ATTENTION_SLIDING_WINDOW,
153
181
  LLM_KV_ATTENTION_SCALE,
182
+ LLM_KV_ATTENTION_OUTPUT_SCALE,
183
+ LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
154
184
  LLM_KV_ATTENTION_KEY_LENGTH_MLA,
155
185
  LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
156
- LLM_KV_ATTENTION_LAYER_INDICES,
157
186
 
158
187
  LLM_KV_ROPE_DIMENSION_COUNT,
159
188
  LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -165,6 +194,10 @@ enum llm_kv {
165
194
  LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
166
195
  LLM_KV_ROPE_SCALING_FINETUNED,
167
196
  LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
197
+ LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
198
+ LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
199
+ LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
200
+ LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,
168
201
 
169
202
  LLM_KV_SPLIT_NO,
170
203
  LLM_KV_SPLIT_COUNT,
@@ -174,6 +207,7 @@ enum llm_kv {
174
207
  LLM_KV_SSM_CONV_KERNEL,
175
208
  LLM_KV_SSM_STATE_SIZE,
176
209
  LLM_KV_SSM_TIME_STEP_RANK,
210
+ LLM_KV_SSM_GROUP_COUNT,
177
211
  LLM_KV_SSM_DT_B_C_RMS,
178
212
 
179
213
  LLM_KV_WKV_HEAD_SIZE,
@@ -212,6 +246,9 @@ enum llm_kv {
212
246
 
213
247
  LLM_KV_ADAPTER_TYPE,
214
248
  LLM_KV_ADAPTER_LORA_ALPHA,
249
+ LLM_KV_ADAPTER_LORA_TASK_NAME,
250
+ LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,
251
+ LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS,
215
252
 
216
253
  LLM_KV_POSNET_EMBEDDING_LENGTH,
217
254
  LLM_KV_POSNET_BLOCK_COUNT,
@@ -221,6 +258,8 @@ enum llm_kv {
221
258
 
222
259
  LLM_KV_CLASSIFIER_OUTPUT_LABELS,
223
260
 
261
+ LLM_KV_SHORTCONV_L_CACHE,
262
+
224
263
  // deprecated:
225
264
  LLM_KV_TOKENIZER_PREFIX_ID,
226
265
  LLM_KV_TOKENIZER_SUFFIX_ID,
@@ -247,6 +286,7 @@ enum llm_tensor {
247
286
  LLM_TENSOR_ATTN_OUT_NORM,
248
287
  LLM_TENSOR_ATTN_POST_NORM,
249
288
  LLM_TENSOR_ATTN_ROT_EMBD,
289
+ LLM_TENSOR_ATTN_SINKS,
250
290
  LLM_TENSOR_FFN_GATE_INP,
251
291
  LLM_TENSOR_FFN_GATE_INP_SHEXP,
252
292
  LLM_TENSOR_FFN_NORM,
@@ -265,6 +305,9 @@ enum llm_tensor {
265
305
  LLM_TENSOR_FFN_DOWN_SHEXP,
266
306
  LLM_TENSOR_FFN_GATE_SHEXP,
267
307
  LLM_TENSOR_FFN_UP_SHEXP,
308
+ LLM_TENSOR_FFN_DOWN_CHEXPS,
309
+ LLM_TENSOR_FFN_GATE_CHEXPS,
310
+ LLM_TENSOR_FFN_UP_CHEXPS,
268
311
  LLM_TENSOR_FFN_EXP_PROBS_B,
269
312
  LLM_TENSOR_ATTN_Q_NORM,
270
313
  LLM_TENSOR_ATTN_K_NORM,
@@ -291,8 +334,12 @@ enum llm_tensor {
291
334
  LLM_TENSOR_SSM_CONV1D,
292
335
  LLM_TENSOR_SSM_X,
293
336
  LLM_TENSOR_SSM_DT,
337
+ LLM_TENSOR_SSM_DT_NORM,
294
338
  LLM_TENSOR_SSM_A,
339
+ LLM_TENSOR_SSM_B_NORM,
340
+ LLM_TENSOR_SSM_C_NORM,
295
341
  LLM_TENSOR_SSM_D,
342
+ LLM_TENSOR_SSM_NORM,
296
343
  LLM_TENSOR_SSM_OUT,
297
344
  LLM_TENSOR_TIME_MIX_W0,
298
345
  LLM_TENSOR_TIME_MIX_W1,
@@ -386,6 +433,15 @@ enum llm_tensor {
386
433
  LLM_TENSOR_POS_NET_ATTN_K,
387
434
  LLM_TENSOR_POS_NET_ATTN_V,
388
435
  LLM_TENSOR_POS_NET_ATTN_OUT,
436
+ LLM_TENSOR_SHORTCONV_CONV,
437
+ LLM_TENSOR_SHORTCONV_INPROJ,
438
+ LLM_TENSOR_SHORTCONV_OUTPROJ,
439
+ LLM_TENSOR_NEXTN_EH_PROJ,
440
+ LLM_TENSOR_NEXTN_EMBED_TOKENS,
441
+ LLM_TENSOR_NEXTN_ENORM,
442
+ LLM_TENSOR_NEXTN_HNORM,
443
+ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
444
+ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
389
445
  };
390
446
 
391
447
  enum llm_tensor_layer {
@@ -462,3 +518,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
462
518
 
463
519
  bool llm_arch_is_recurrent(const llm_arch & arch);
464
520
  bool llm_arch_is_hybrid (const llm_arch & arch);
521
+ bool llm_arch_is_diffusion(const llm_arch & arch);
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
27
27
  const llama_vocab & vocab,
28
28
  const llama_memory_i * memory,
29
29
  uint32_t n_embd,
30
+ uint32_t n_seq_max,
30
31
  bool output_all) {
31
32
  clear();
32
33
 
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
40
41
  // validate input batch
41
42
  //
42
43
 
44
+ if (n_seq_max > LLAMA_MAX_SEQ) {
45
+ LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
46
+ return false;
47
+ }
48
+
43
49
  if (batch.token) {
44
50
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
45
51
  if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
52
58
  if (batch.seq_id) {
53
59
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
54
60
  for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
55
- if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
56
- LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
61
+ if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
62
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
57
63
  return false;
58
64
  }
59
65
  }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
86
92
 
87
93
  // initialize the starting position for each sequence based on the positions in the memory
88
94
  llama_pos p0[LLAMA_MAX_SEQ];
89
- for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
95
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
90
96
  if (!memory) {
91
97
  // if no memory -> start from 0
92
98
  p0[s] = 0;
@@ -143,13 +149,16 @@ bool llama_batch_allocr::init(
143
149
  // compute stats
144
150
  //
145
151
 
146
- this->n_embd = n_embd;
152
+ this->n_embd = n_embd;
153
+ this->n_seq_max = n_seq_max;
147
154
 
148
155
  // count the outputs in this batch
149
156
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
150
157
  n_outputs += batch.logits[i] != 0;
151
158
  }
152
159
 
160
+ has_cpl = false;
161
+
153
162
  // determine coupled sequences
154
163
  // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
155
164
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -166,6 +175,8 @@ bool llama_batch_allocr::init(
166
175
 
167
176
  // note: tracking the other way around is not necessary for now
168
177
  //seq_cpl[s0][s1] = true;
178
+
179
+ has_cpl = true;
169
180
  }
170
181
  }
171
182
  }
@@ -187,7 +198,7 @@ bool llama_batch_allocr::init(
187
198
  seq_set_map[cur].push_back(i);
188
199
  }
189
200
 
190
- for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
201
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
191
202
  if (seq_set_unq.test(s)) {
192
203
  seq_idx[s] = seq_id_unq.size();
193
204
  seq_id_unq.push_back(s);
@@ -199,7 +210,7 @@ bool llama_batch_allocr::init(
199
210
  LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
200
211
 
201
212
  llama_ubatch ubatch {
202
- /*.equal_seqs =*/ false,
213
+ /*.b_equal_seqs =*/ false,
203
214
  /*.n_tokens =*/ (uint32_t) batch.n_tokens,
204
215
  /*.n_seq_tokens =*/ (uint32_t) 1,
205
216
  /*.n_seqs =*/ (uint32_t) batch.n_tokens,
@@ -212,6 +223,7 @@ bool llama_batch_allocr::init(
212
223
  /*.seq_id_unq =*/ this->seq_id_unq.data(),
213
224
  /*.seq_idx =*/ this->seq_idx.data(),
214
225
  /*.output =*/ batch.logits,
226
+ /*.data =*/ {},
215
227
  };
216
228
 
217
229
  ubatch_print(ubatch, debug);
@@ -239,7 +251,7 @@ bool llama_batch_allocr::init(
239
251
  // consistency checks
240
252
  //
241
253
 
242
- for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
254
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
243
255
  if (seq_pos[s].empty()) {
244
256
  continue;
245
257
  }
@@ -282,8 +294,8 @@ bool llama_batch_allocr::init(
282
294
  }
283
295
 
284
296
  if (memory) {
285
- for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
286
- for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
297
+ for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
298
+ for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
287
299
  if (seq_cpl[s0][s1]) {
288
300
  if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
289
301
  memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -314,12 +326,12 @@ bool llama_batch_allocr::init(
314
326
  //
315
327
  {
316
328
  seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
317
- for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
329
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
318
330
  cur_seq_set[s].set();
319
331
  }
320
332
 
321
333
  llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
322
- for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
334
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
323
335
  cur_seq_pos[s] = -1;
324
336
  }
325
337
 
@@ -355,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
355
367
  clear();
356
368
  split_reset();
357
369
 
358
- ubatches.emplace_back();
359
-
360
- auto & ubatch = ubatches.back();
370
+ auto udata = std::make_shared<llama_ubatch::data_t>();
361
371
 
362
- ubatch.token .resize(n_tokens);
363
- ubatch.embd .clear();
364
- ubatch.pos .resize(n_tokens);
365
- ubatch.n_seq_id .resize(n_tokens);
366
- ubatch.seq_id .resize(n_tokens);
367
- ubatch.seq_id_unq.resize(0);
368
- ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
369
- ubatch.output .resize(n_tokens);
372
+ udata->token .resize(n_tokens);
373
+ udata->embd .clear();
374
+ udata->pos .resize(n_tokens);
375
+ udata->n_seq_id .resize(n_tokens);
376
+ udata->seq_id .resize(n_tokens);
377
+ udata->seq_id_unq.resize(0);
378
+ udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
379
+ udata->output .resize(n_tokens);
370
380
 
371
381
  for (uint32_t s = 0; s < n_seqs; ++s) {
372
- ubatch.seq_idx[s] = s;
373
- ubatch.seq_id_unq.push_back(s);
382
+ udata->seq_idx[s] = s;
383
+ udata->seq_id_unq.push_back(s);
374
384
  }
375
385
 
376
386
  llama_ubatch res {
377
- /*.equal_seqs =*/ true,
387
+ /*.b_equal_seqs =*/ true,
378
388
  /*.n_tokens =*/ n_tokens,
379
389
  /*.n_seq_tokens =*/ n_seq_tokens,
380
390
  /*.n_seqs =*/ n_seqs,
381
391
  /*.n_seqs_unq =*/ n_seqs,
382
392
 
383
- /*.token =*/ ubatch.token.data(),
393
+ /*.token =*/ udata->token.data(),
384
394
  /*.embd =*/ nullptr,
385
- /*.pos =*/ ubatch.pos.data(),
386
- /*.n_seq_id =*/ ubatch.n_seq_id.data(),
387
- /*.seq_id =*/ ubatch.seq_id.data(),
388
- /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
389
- /*.seq_idx =*/ ubatch.seq_idx.data(),
390
- /*.output =*/ ubatch.output.data(),
395
+ /*.pos =*/ udata->pos.data(),
396
+ /*.n_seq_id =*/ udata->n_seq_id.data(),
397
+ /*.seq_id =*/ udata->seq_id.data(),
398
+ /*.seq_id_unq =*/ udata->seq_id_unq.data(),
399
+ /*.seq_idx =*/ udata->seq_idx.data(),
400
+ /*.output =*/ udata->output.data(),
401
+ /*.data =*/ std::move(udata),
391
402
  };
392
403
 
393
404
  return res;
@@ -405,6 +416,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
405
416
  return n_outputs;
406
417
  }
407
418
 
419
+ uint32_t llama_batch_allocr::get_n_used() const {
420
+ return n_used;
421
+ }
422
+
408
423
  std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
409
424
  return out_ids;
410
425
  }
@@ -420,10 +435,10 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
420
435
  void llama_batch_allocr::split_reset() {
421
436
  out_ids.clear();
422
437
 
438
+ n_used = 0;
439
+
423
440
  used.clear();
424
441
  used.resize(get_n_tokens(), false);
425
-
426
- ubatches.clear();
427
442
  }
428
443
 
429
444
  llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@@ -444,6 +459,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
444
459
  idxs.push_back(cur_idx);
445
460
 
446
461
  used[cur_idx] = true;
462
+ ++n_used;
447
463
 
448
464
  ++cur_idx;
449
465
 
@@ -459,9 +475,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
459
475
  return ubatch_add(idxs, idxs.size(), false);
460
476
  }
461
477
 
462
- llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
478
+ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
479
+ if (sequential && has_cpl) {
480
+ LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__);
481
+
482
+ return {};
483
+ }
484
+
463
485
  std::vector<seq_set_t> cur_seq_set;
464
486
 
487
+ llama_seq_id last_seq_id = -1;
488
+
465
489
  // determine the non-overlapping sequence sets participating in this ubatch
466
490
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
467
491
  if (used[i]) {
@@ -478,9 +502,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
478
502
  }
479
503
  }
480
504
 
505
+ // accept only increasing sequence ids
506
+ if (sequential) {
507
+ add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
508
+ }
509
+
481
510
  if (add) {
482
511
  cur_seq_set.push_back(seq_set[i]);
483
512
 
513
+ last_seq_id = batch.seq_id[i][0];
514
+
484
515
  if (cur_seq_set.size() > n_ubatch) {
485
516
  break;
486
517
  }
@@ -529,6 +560,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
529
560
  idxs_per_seq[s].push_back(idx);
530
561
 
531
562
  used[idx] = true;
563
+ ++n_used;
532
564
 
533
565
  ++cur_idx[s];
534
566
  }
@@ -570,6 +602,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
570
602
  idxs.push_back(cur_idx);
571
603
 
572
604
  used[cur_idx] = true;
605
+ ++n_used;
573
606
 
574
607
  if (idxs.size() >= n_ubatch) {
575
608
  break;
@@ -620,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
620
653
 
621
654
  assert(n_tokens%n_seqs == 0);
622
655
 
623
- ubatches.emplace_back();
624
-
625
- auto & ubatch = ubatches.back();
656
+ auto udata = std::make_shared<llama_ubatch::data_t>();
626
657
 
627
658
  const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
628
659
 
629
660
  const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
630
661
  const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
631
662
 
632
- ubatch.token .resize(n_tokens);
633
- ubatch.embd .resize(n_embd_all);
634
- ubatch.pos .resize(n_pos_all);
635
- ubatch.n_seq_id .resize(n_tokens);
636
- ubatch.seq_id .resize(n_tokens);
637
- ubatch.seq_id_unq.resize(0);
638
- ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
639
- ubatch.output .resize(n_tokens);
663
+ udata->token .resize(n_tokens);
664
+ udata->embd .resize(n_embd_all);
665
+ udata->pos .resize(n_pos_all);
666
+ udata->n_seq_id .resize(n_tokens);
667
+ udata->seq_id .resize(n_tokens);
668
+ udata->seq_id_unq.resize(0);
669
+ udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
670
+ udata->output .resize(n_tokens);
640
671
 
641
672
  seq_set_t seq_set_unq;
642
673
 
643
674
  for (size_t i = 0; i < idxs.size(); ++i) {
644
675
  if (batch.token) {
645
- ubatch.token[i] = batch.token[idxs[i]];
676
+ udata->token[i] = batch.token[idxs[i]];
646
677
  }
647
678
 
648
679
  if (batch.embd) {
649
- memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
680
+ memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
650
681
  }
651
682
 
652
683
  for (int j = 0; j < n_pos_cur; ++j) {
653
- ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
684
+ udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
654
685
  }
655
686
 
656
- ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
657
- ubatch.seq_id[i] = batch.seq_id[idxs[i]];
658
- ubatch.output[i] = batch.logits[idxs[i]];
687
+ udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
688
+ udata->seq_id[i] = batch.seq_id[idxs[i]];
689
+ udata->output[i] = batch.logits[idxs[i]];
659
690
 
660
- for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
661
- seq_set_unq.set(ubatch.seq_id[i][s]);
691
+ for (int s = 0; s < udata->n_seq_id[i]; ++s) {
692
+ seq_set_unq.set(udata->seq_id[i][s]);
662
693
  }
663
694
 
664
- if (ubatch.output[i]) {
695
+ if (udata->output[i]) {
665
696
  out_ids.push_back(idxs[i]);
666
697
  }
667
698
  }
668
699
 
669
- for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
700
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
670
701
  if (seq_set_unq.test(s)) {
671
- ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
672
- ubatch.seq_id_unq.push_back(s);
702
+ udata->seq_idx[s] = udata->seq_id_unq.size();
703
+ udata->seq_id_unq.push_back(s);
673
704
  }
674
705
  }
675
706
 
676
707
  llama_ubatch res {
677
- /*.equal_seqs =*/ equal_seqs,
708
+ /*.b_equal_seqs =*/ equal_seqs,
678
709
  /*.n_tokens =*/ n_tokens,
679
710
  /*.n_seq_tokens =*/ n_tokens/n_seqs,
680
711
  /*.n_seqs =*/ n_seqs,
681
- /*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(),
682
-
683
- /*.token =*/ batch.token ? ubatch.token.data() : nullptr,
684
- /*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
685
- /*.pos =*/ ubatch.pos.data(),
686
- /*.n_seq_id =*/ ubatch.n_seq_id.data(),
687
- /*.seq_id =*/ ubatch.seq_id.data(),
688
- /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
689
- /*.seq_idx =*/ ubatch.seq_idx.data(),
690
- /*.output =*/ ubatch.output.data(),
712
+ /*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
713
+
714
+ /*.token =*/ batch.token ? udata->token.data() : nullptr,
715
+ /*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
716
+ /*.pos =*/ udata->pos.data(),
717
+ /*.n_seq_id =*/ udata->n_seq_id.data(),
718
+ /*.seq_id =*/ udata->seq_id.data(),
719
+ /*.seq_id_unq =*/ udata->seq_id_unq.data(),
720
+ /*.seq_idx =*/ udata->seq_idx.data(),
721
+ /*.output =*/ udata->output.data(),
722
+ /*.data =*/ std::move(udata),
691
723
  };
692
724
 
693
725
  if (debug > 0) {
694
- LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
726
+ LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
695
727
 
696
728
  ubatch_print(res, debug);
697
729
  }
@@ -701,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
701
733
 
702
734
  void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
703
735
  if (debug > 0) {
704
- LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
736
+ LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs());
705
737
  LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
706
738
  LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
707
739
  LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
@@ -8,12 +8,17 @@
8
8
  #include <vector>
9
9
  #include <set>
10
10
  #include <bitset>
11
+ #include <memory>
11
12
  #include <unordered_map>
12
13
 
13
14
  // keep this struct lightweight
14
- // it points to data in `llama_batch_allocr`
15
15
  struct llama_ubatch {
16
- bool equal_seqs;
16
+ bool equal_seqs() const {
17
+ return b_equal_seqs != 0;
18
+ }
19
+
20
+ uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
21
+ // otherwise address sanitizer complains
17
22
  // TODO: whole_seqs for embeddings?
18
23
 
19
24
  uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
34
39
  llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
35
40
  int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
36
41
  int8_t * output; // [n_tokens] | i | -
42
+
43
+ struct data_t {
44
+ std::vector<llama_token> token;
45
+ std::vector<float> embd;
46
+ std::vector<llama_pos> pos;
47
+ std::vector<int32_t> n_seq_id;
48
+ std::vector<llama_seq_id *> seq_id;
49
+ std::vector<llama_seq_id> seq_id_unq;
50
+ std::vector<int32_t> seq_idx;
51
+ std::vector<int8_t> output;
52
+ };
53
+
54
+ // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
55
+ std::shared_ptr<data_t> data;
37
56
  };
38
57
 
39
58
  // a helper for sanitizing, fulfilling and splitting a batch
@@ -48,12 +67,14 @@ public:
48
67
  const llama_vocab & vocab,
49
68
  const llama_memory_i * memory,
50
69
  uint32_t n_embd,
70
+ uint32_t n_seq_max,
51
71
  bool output_all);
52
72
 
53
73
  const llama_batch & get_batch() const;
54
74
 
55
75
  uint32_t get_n_tokens() const;
56
76
  uint32_t get_n_outputs() const;
77
+ uint32_t get_n_used() const;
57
78
 
58
79
  // the array of output indices in the order they were encountered during the ubatch splitting
59
80
  std::vector<int32_t> & get_out_ids();
@@ -69,7 +90,8 @@ public:
69
90
  llama_ubatch split_simple(uint32_t n_ubatch);
70
91
 
71
92
  // make ubatches of equal-length sequences sets
72
- llama_ubatch split_equal(uint32_t n_ubatch);
93
+ // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
94
+ llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
73
95
 
74
96
  // sequence-set-wise split - each ubatch contains a single sequence-set
75
97
  llama_ubatch split_seq(uint32_t n_ubatch);
@@ -98,6 +120,7 @@ private:
98
120
  const uint32_t n_pos_per_embd;
99
121
 
100
122
  uint32_t n_embd;
123
+ uint32_t n_seq_max;
101
124
  uint32_t n_outputs;
102
125
 
103
126
  std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -112,6 +135,9 @@ private:
112
135
  using pos_set_t = std::set<llama_pos>;
113
136
  using seq_cpl_t = std::vector<bool>;
114
137
 
138
+ // helper flag to quickly determine if there are any coupled sequences in the batch
139
+ bool has_cpl = false;
140
+
115
141
  std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
116
142
  std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
117
143
 
@@ -125,23 +151,10 @@ private:
125
151
  // batch indices of the output
126
152
  std::vector<int32_t> out_ids;
127
153
 
154
+ uint32_t n_used;
155
+
128
156
  // used[i] indicates if token i has already been used in a previous ubatch
129
157
  std::vector<bool> used;
130
158
 
131
- // llama_ubatch points to this data:
132
- struct ubatch {
133
- std::vector<llama_token> token;
134
- std::vector<float> embd;
135
- std::vector<llama_pos> pos;
136
- std::vector<int32_t> n_seq_id;
137
- std::vector<llama_seq_id *> seq_id;
138
- std::vector<llama_seq_id> seq_id_unq;
139
- std::vector<int32_t> seq_idx;
140
- std::vector<int8_t> output;
141
- };
142
-
143
- // current splitting state:
144
- std::vector<ubatch> ubatches;
145
-
146
159
  int debug;
147
160
  };