whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017):
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -1,3771 +0,0 @@
1
- #include "llama-sampling.h"
2
-
3
- #include "llama-impl.h"
4
- #include "llama-vocab.h"
5
- #include "llama-grammar.h"
6
-
7
- #include "ggml-cpp.h"
8
-
9
- #include <array>
10
- #include <algorithm>
11
- #include <cassert>
12
- #include <cfloat>
13
- #include <chrono>
14
- #include <cmath>
15
- #include <cstdlib>
16
- #include <cstring>
17
- #include <ctime>
18
- #include <numeric>
19
- #include <random>
20
- #include <unordered_map>
21
- #include <stdexcept>
22
-
23
// Fixed-capacity ring buffer with a std::deque-like interface.
// When the buffer is full, push_back() overwrites the oldest element.
//
// Index bookkeeping:
//   first - slot of the oldest element
//   pos   - slot the next push_back() writes to (one past the newest element)
//   sz    - number of live elements (0 <= sz <= capacity)
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // Oldest element. Throws std::runtime_error if the buffer is empty.
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // Newest (most recently pushed) element. Throws std::runtime_error if empty.
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        // `pos` is one past the newest element, so step back one slot with wrap-around
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    // Appends `value`; when full, the oldest element is dropped to make room.
    // Throws std::runtime_error if the buffer was created with capacity 0.
    void push_back(const T & value) {
        if (capacity == 0) {
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // Removes and returns the oldest element. Throws std::runtime_error if empty.
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // Reverse access: rat(0) is the newest element, rat(size() - 1) the oldest.
    // Throws std::runtime_error if i >= size().
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // Copies the live elements, oldest first.
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    // Logically empties the buffer; stored values are not destroyed.
    void clear() {
        // here only reset the status of the buffer
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz = 0;
    size_t first = 0;
    size_t pos = 0;

    std::vector<T> data;
};
133
-
134
- // writes result in res, does not mutate cur
135
- static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
136
- static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
137
- return a.logit > b.logit;
138
- };
139
-
140
- constexpr int nbuckets = 128;
141
- constexpr float bucket_low = -10.0f;
142
- constexpr float bucket_high = 10.0f;
143
- constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
144
- constexpr float bucket_inter = -bucket_low * bucket_scale;
145
-
146
- std::vector<int> bucket_idx;
147
- std::vector<int> histo(nbuckets, 0);
148
-
149
- std::vector<llama_token_data*> bucket_ptrs;
150
-
151
- bucket_idx.reserve(cur.size);
152
-
153
- for (int i = 0; i < (int)cur.size; ++i) {
154
- const float val = cur.data[i].logit;
155
- int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
156
- ib = std::max(0, std::min(nbuckets - 1, ib));
157
- bucket_idx.push_back(ib);
158
- ++histo[ib];
159
- }
160
- int nhave = 0;
161
- int ib = nbuckets - 1;
162
- for ( ; ib >= 0; --ib) {
163
- nhave += histo[ib];
164
- if (nhave >= npartial) {
165
- break;
166
- }
167
- }
168
- res.resize(nhave);
169
- auto * ptr = res.data();
170
- bucket_ptrs.reserve(nbuckets - ib);
171
- for (int j = nbuckets - 1; j >= ib; --j) {
172
- bucket_ptrs.push_back(ptr);
173
- ptr += histo[j];
174
- }
175
- for (int i = 0; i < (int)cur.size; ++i) {
176
- int j = bucket_idx[i];
177
- if (j >= ib) {
178
- *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i];
179
- }
180
- }
181
-
182
- ptr = res.data();
183
- int ndone = 0;
184
- for (int j = nbuckets - 1; j > ib; --j) {
185
- std::sort(ptr, ptr + histo[j], comp);
186
- ptr += histo[j];
187
- ndone += histo[j];
188
- }
189
- std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
190
- }
191
-
192
- // reduces the size of cur_p to npartial, keeping only the top npartial elements
193
- static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) {
194
- static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
195
- return a.logit > b.logit;
196
- };
197
-
198
- if (npartial <= 128) {
199
- std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp);
200
-
201
- cur_p->size = npartial;
202
- cur_p->sorted = true;
203
-
204
- return;
205
- }
206
-
207
- std::vector<llama_token_data> tmp;
208
-
209
- llama_token_data_array_partial_sort(*cur_p, npartial, tmp);
210
-
211
- std::copy(tmp.data(), tmp.data() + npartial, cur_p->data);
212
-
213
- cur_p->size = npartial;
214
- cur_p->sorted = true;
215
- }
216
-
217
- static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
218
- // iterator for the probabilities
219
- #ifdef __GNUC__
220
- #pragma GCC diagnostic push
221
- #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
222
- #endif
223
-
224
- struct probs_iterator {
225
- typedef std::input_iterator_tag iterator_category;
226
- typedef float value_type;
227
- typedef float * pointer;
228
- typedef float & reference;
229
- typedef ptrdiff_t difference_type;
230
-
231
- const llama_token_data * data;
232
-
233
- bool operator==(const probs_iterator & other) const { return data == other.data; }
234
- bool operator!=(const probs_iterator & other) const { return data != other.data; }
235
- const float & operator*() const { return data->p; }
236
- probs_iterator & operator++() { ++data; return *this; }
237
- probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
238
- };
239
-
240
- #ifdef __GNUC__
241
- #pragma GCC diagnostic pop
242
- #endif
243
-
244
- std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
245
-
246
- return dist(rng);
247
- }
248
-
249
- /*
250
- static void llama_log_softmax(float * array, size_t size) {
251
- float max_l = *std::max_element(array, array + size);
252
- float sum = 0.f;
253
- for (size_t i = 0; i < size; ++i) {
254
- float p = expf(array[i] - max_l);
255
- sum += p;
256
- array[i] = p;
257
- }
258
-
259
- for (size_t i = 0; i < size; ++i) {
260
- array[i] = logf(array[i] / sum);
261
- }
262
- }
263
- */
264
-
265
- static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
266
- if (temp <= 0.0f) {
267
- // find the token with the highest logit and set the rest to -inf
268
- size_t max_i = 0;
269
- float max_l = cur_p->data[0].logit;
270
-
271
- for (size_t i = 1; i < cur_p->size; ++i) {
272
- if (cur_p->data[i ].logit > max_l) {
273
- cur_p->data[max_i].logit = -INFINITY;
274
- max_i = i;
275
- max_l = cur_p->data[i].logit;
276
- } else {
277
- cur_p->data[i].logit = -INFINITY;
278
- }
279
- }
280
-
281
- return;
282
- }
283
-
284
- for (size_t i = 0; i < cur_p->size; ++i) {
285
- cur_p->data[i].logit /= temp;
286
- }
287
- }
288
-
289
- static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) {
290
- GGML_ASSERT(cur_p->size > 0);
291
-
292
- // Sort the logits in descending order if requested
293
- if (do_sort && !cur_p->sorted) {
294
- llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
295
- }
296
-
297
- float max_l = cur_p->data[0].logit;
298
- if (!cur_p->sorted) {
299
- for (size_t i = 1; i < cur_p->size; ++i) {
300
- max_l = std::max(max_l, cur_p->data[i].logit);
301
- }
302
- }
303
-
304
- float cum_sum = 0.0f;
305
-
306
- for (size_t i = 0; i < cur_p->size; ++i) {
307
- float p = expf(cur_p->data[i].logit - max_l);
308
- cur_p->data[i].p = p;
309
- cum_sum += p;
310
- }
311
-
312
- for (size_t i = 0; i < cur_p->size; ++i) {
313
- cur_p->data[i].p /= cum_sum;
314
- }
315
- }
316
-
317
- static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
318
- // if (k >= (int32_t)cur_p->size) {
319
- // return;
320
- // }
321
-
322
- if (k <= 0) {
323
- return;
324
- }
325
-
326
- k = std::min(k, (int) cur_p->size);
327
-
328
- // Sort scores in descending order
329
- if (!cur_p->sorted) {
330
- llama_token_data_array_partial_sort_inplace(cur_p, k);
331
- }
332
-
333
- cur_p->size = k;
334
- }
335
-
336
- static uint32_t get_rng_seed(uint32_t seed) {
337
- if (seed == LLAMA_DEFAULT_SEED) {
338
- // use system clock if std::random_device is not a true RNG
339
- static bool is_rd_prng = std::random_device().entropy() == 0;
340
- if (is_rd_prng) {
341
- return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
342
- }
343
- std::random_device rd;
344
- return rd();
345
- }
346
- return seed;
347
- }
348
-
349
- // llama_sampler API
350
-
351
- struct llama_sampler * llama_sampler_init(
352
- struct llama_sampler_i * iface,
353
- llama_sampler_context_t ctx) {
354
- return new llama_sampler {
355
- /* .iface = */ iface,
356
- /* .ctx = */ ctx,
357
- };
358
- }
359
-
360
- const char * llama_sampler_name(const struct llama_sampler * smpl) {
361
- if (!smpl->iface) {
362
- return "(null)";
363
- }
364
-
365
- return smpl->iface->name(smpl);
366
- }
367
-
368
- void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
369
- if (!smpl) {
370
- return;
371
- }
372
-
373
- if (smpl->iface->accept) {
374
- smpl->iface->accept(smpl, token);
375
- }
376
- }
377
-
378
- void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
379
- if (!smpl) {
380
- return;
381
- }
382
-
383
- GGML_ASSERT(smpl->iface->apply);
384
- smpl->iface->apply(smpl, cur_p);
385
- }
386
-
387
- void llama_sampler_reset(struct llama_sampler * smpl) {
388
- if (!smpl) {
389
- return;
390
- }
391
-
392
- if (smpl->iface->reset) {
393
- smpl->iface->reset(smpl);
394
- }
395
- }
396
-
397
- struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
398
- if (!smpl) {
399
- return nullptr;
400
- }
401
-
402
- if (smpl->iface->clone) {
403
- return smpl->iface->clone(smpl);
404
- }
405
-
406
- if (smpl->ctx == nullptr) {
407
- return llama_sampler_init(
408
- /* .iface = */ smpl->iface,
409
- /* .ctx = */ nullptr
410
- );
411
- }
412
-
413
- GGML_ABORT("the sampler does not support cloning");
414
- }
415
-
416
- void llama_sampler_free(struct llama_sampler * smpl) {
417
- if (smpl == nullptr) {
418
- return;
419
- }
420
-
421
- if (smpl->iface->free) {
422
- smpl->iface->free(smpl);
423
- }
424
-
425
- delete smpl;
426
- }
427
-
428
- // empty sampler
429
-
430
- struct llama_sampler_empty {
431
- const char * name;
432
- };
433
-
434
- static struct llama_sampler * llama_sampler_init_empty(const char * name);
435
-
436
- static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
437
- auto * ctx = (llama_sampler_empty *) smpl->ctx;
438
- return ctx->name;
439
- }
440
-
441
- static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
442
- GGML_UNUSED(smpl);
443
- GGML_UNUSED(token);
444
- }
445
-
446
- static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
447
- GGML_UNUSED(smpl);
448
- GGML_UNUSED(cur_p);
449
- }
450
-
451
- static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
452
- GGML_UNUSED(smpl);
453
- }
454
-
455
- static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
456
- auto * ctx = (llama_sampler_empty *) smpl->ctx;
457
- return llama_sampler_init_empty(ctx->name);
458
- }
459
-
460
- static void llama_sampler_empty_free(struct llama_sampler * smpl) {
461
- delete (llama_sampler_empty *) smpl->ctx;
462
- }
463
-
464
- static bool llama_sampler_empty_backend_init(
465
- struct llama_sampler * smpl,
466
- ggml_backend_buffer_type_t buft) {
467
- GGML_UNUSED(smpl);
468
- GGML_UNUSED(buft);
469
-
470
- return true;
471
- }
472
-
473
- static void llama_sampler_empty_backend_accept(
474
- struct llama_sampler * smpl,
475
- ggml_context * ctx,
476
- ggml_cgraph * gf,
477
- struct ggml_tensor * selected_token) {
478
- GGML_UNUSED(smpl);
479
- GGML_UNUSED(ctx);
480
- GGML_UNUSED(gf);
481
- GGML_UNUSED(selected_token);
482
- }
483
-
484
- static void llama_sampler_empty_backend_apply(
485
- struct llama_sampler * smpl,
486
- struct ggml_context * ctx,
487
- struct ggml_cgraph * gf,
488
- struct llama_sampler_data * data) {
489
- GGML_UNUSED(smpl);
490
- GGML_UNUSED(ctx);
491
- GGML_UNUSED(gf);
492
- GGML_UNUSED(data);
493
- }
494
-
495
- static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
496
- GGML_UNUSED(smpl);
497
- }
498
-
499
- static struct llama_sampler_i llama_sampler_empty_i = {
500
- /* .name = */ llama_sampler_empty_name,
501
- /* .accept = */ llama_sampler_empty_accept,
502
- /* .apply = */ llama_sampler_empty_apply,
503
- /* .reset = */ llama_sampler_empty_reset,
504
- /* .clone = */ llama_sampler_empty_clone,
505
- /* .free = */ llama_sampler_empty_free,
506
- /* .backend_init = */ llama_sampler_empty_backend_init,
507
- /* .backend_accept = */ llama_sampler_empty_backend_accept,
508
- /* .backend_apply = */ llama_sampler_empty_backend_apply,
509
- /* .backend_set_input = */ llama_sampler_empty_backend_set_input,
510
- };
511
-
512
- struct llama_sampler * llama_sampler_init_empty(const char * name) {
513
- return llama_sampler_init(
514
- /* .iface = */ &llama_sampler_empty_i,
515
- /* .ctx = */ new llama_sampler_empty {
516
- /* .name = */ name,
517
- }
518
- );
519
- }
520
-
521
- // common backend sampler functionality
522
- //
523
- // +name : means that the sampler is support and will run on the backend
524
- // -name : means that a ggml operator is not supported by the backend
525
- //
526
- struct llama_sampler_backend {
527
- llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
528
-
529
- const char * get_name() {
530
- if (!is_init) {
531
- return name.c_str();
532
- }
533
-
534
- if (support) {
535
- name_ext = "+" + name;
536
- } else {
537
- name_ext = "-" + name;
538
- }
539
-
540
- return name_ext.c_str();
541
- }
542
-
543
- void init(bool support) {
544
- GGML_ASSERT(this->is_init == false);
545
-
546
- this->is_init = true;
547
- this->support = support;
548
- }
549
-
550
- private:
551
- std::string name;
552
- std::string name_ext;
553
-
554
- bool is_init;
555
- bool support;
556
- };
557
-
558
- // check if all ggml ops used by the sampler are supported by the backend
559
- static bool llama_sampler_backend_support(
560
- llama_sampler * smpl,
561
- ggml_backend_buffer_type_t buft) {
562
- auto * device = ggml_backend_buft_get_device(buft);
563
- if (!device) {
564
- // CPU backend always supported
565
- return true;
566
- }
567
-
568
- ggml_init_params params = {
569
- /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
570
- /*.mem_buffer =*/ NULL,
571
- /*.no_alloc =*/ true,
572
- };
573
-
574
- ggml_context_ptr ctx_ptr { ggml_init(params) };
575
- if (!ctx_ptr) {
576
- throw std::runtime_error(format("failed to create ggml context"));
577
- }
578
-
579
- ggml_context * ctx = ctx_ptr.get();
580
-
581
- const int64_t n = 1024*1024;
582
-
583
- llama_sampler_data data = {
584
- /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
585
- /*.probs = */ nullptr,
586
- /*.sampled = */ nullptr,
587
- /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
588
- };
589
-
590
- ggml_cgraph * gf = ggml_new_graph(ctx);
591
-
592
- smpl->iface->backend_apply(smpl, ctx, gf, &data);
593
-
594
- if (data.logits) {
595
- ggml_build_forward_expand(gf, data.logits);
596
- }
597
-
598
- if (data.probs) {
599
- ggml_build_forward_expand(gf, data.probs);
600
- }
601
-
602
- if (data.sampled) {
603
- ggml_build_forward_expand(gf, data.sampled);
604
- }
605
-
606
- if (data.candidates) {
607
- ggml_build_forward_expand(gf, data.candidates);
608
- }
609
-
610
- for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
611
- struct ggml_tensor * op = ggml_graph_node(gf, i);
612
-
613
- if (!ggml_backend_dev_supports_op(device, op)) {
614
- LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
615
- __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
616
-
617
- return false;
618
- }
619
- }
620
-
621
- return true;
622
- }
623
-
624
- // sampler chain
625
-
626
- static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
627
- return "chain";
628
- }
629
-
630
- static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
631
- auto * chain = (llama_sampler_chain *) smpl->ctx;
632
-
633
- time_meas tm(chain->t_sample_us, chain->params.no_perf);
634
-
635
- for (auto & smpl : chain->samplers) {
636
- llama_sampler_accept(smpl.ptr, token);
637
- }
638
-
639
- chain->n_sample++;
640
- }
641
-
642
- static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
643
- auto * chain = (llama_sampler_chain *) smpl->ctx;
644
-
645
- time_meas tm(chain->t_sample_us, chain->params.no_perf);
646
-
647
- bool is_backend = chain->is_init;
648
-
649
- for (auto & smpl : chain->samplers) {
650
- if (is_backend && smpl.is_backend) {
651
- continue;
652
- }
653
-
654
- is_backend = false;
655
-
656
- if (smpl.ptr->iface->apply == nullptr) {
657
- continue;
658
- }
659
-
660
- llama_sampler_apply(smpl.ptr, cur_p);
661
- }
662
- }
663
-
664
- static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
665
- auto * chain = (llama_sampler_chain *) smpl->ctx;
666
-
667
- for (auto & smpl : chain->samplers) {
668
- llama_sampler_reset(smpl.ptr);
669
- }
670
- }
671
-
672
- static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
673
- const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
674
-
675
- auto * result = llama_sampler_chain_init(chain_src->params);
676
-
677
- for (const auto & smpl : chain_src->samplers) {
678
- llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
679
- }
680
-
681
- return result;
682
- }
683
-
684
- static void llama_sampler_chain_free(struct llama_sampler * smpl) {
685
- auto * chain = (llama_sampler_chain *) smpl->ctx;
686
-
687
- for (auto & smpl : chain->samplers) {
688
- llama_sampler_free(smpl.ptr);
689
- }
690
-
691
- delete chain;
692
- }
693
-
694
- static bool llama_sampler_chain_backend_init(
695
- struct llama_sampler * smpl,
696
- ggml_backend_buffer_type_t buft) {
697
- auto * chain = (llama_sampler_chain *) smpl->ctx;
698
-
699
- GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
700
-
701
- chain->is_init = true;
702
-
703
- bool res = true;
704
-
705
- for (auto & smpl : chain->samplers) {
706
- bool res_cur = true;
707
-
708
- // to be able to run a sampler on the backend, it has to:
709
- // - have the .backend_init() API implemented
710
- // - return true during .backend_init()
711
- if (smpl.ptr->iface->backend_init) {
712
- if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
713
- res_cur = false;
714
- }
715
- } else {
716
- res_cur = false;
717
- }
718
-
719
- smpl.is_backend = res_cur;
720
-
721
- res = res && res_cur;
722
- }
723
-
724
- return res;
725
- }
726
-
727
- static void llama_sampler_chain_backend_accept(
728
- struct llama_sampler * smpl,
729
- ggml_context * ctx,
730
- ggml_cgraph * gf,
731
- struct ggml_tensor * selected_token) {
732
- auto * chain = (llama_sampler_chain *) smpl->ctx;
733
-
734
- for (auto & smpl : chain->samplers) {
735
- if (!smpl.is_backend) {
736
- break;
737
- }
738
-
739
- if (smpl.ptr->iface->backend_accept) {
740
- smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
741
- }
742
- }
743
- }
744
-
745
- static void llama_sampler_chain_backend_apply(
746
- struct llama_sampler * smpl,
747
- struct ggml_context * ctx,
748
- struct ggml_cgraph * gf,
749
- struct llama_sampler_data * data) {
750
- auto * chain = (llama_sampler_chain *) smpl->ctx;
751
-
752
- GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
753
-
754
- for (auto & smpl : chain->samplers) {
755
- if (!smpl.is_backend) {
756
- break;
757
- }
758
-
759
- if (smpl.ptr->iface->backend_apply) {
760
- smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
761
- }
762
- }
763
- }
764
-
765
- static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
766
- auto * chain = (llama_sampler_chain *) smpl->ctx;
767
-
768
- for (auto & smpl : chain->samplers) {
769
- if (!smpl.is_backend) {
770
- break;
771
- }
772
-
773
- if (smpl.ptr->iface->backend_set_input) {
774
- smpl.ptr->iface->backend_set_input(smpl.ptr);
775
- }
776
- }
777
- }
778
-
779
- static struct llama_sampler_i llama_sampler_chain_i = {
780
- /* .name = */ llama_sampler_chain_name,
781
- /* .accept = */ llama_sampler_chain_accept,
782
- /* .apply = */ llama_sampler_chain_apply,
783
- /* .reset = */ llama_sampler_chain_reset,
784
- /* .clone = */ llama_sampler_chain_clone,
785
- /* .free = */ llama_sampler_chain_free,
786
- /* .backend_init = */ llama_sampler_chain_backend_init,
787
- /* .backend_accept = */ llama_sampler_chain_backend_accept,
788
- /* .backend_apply = */ llama_sampler_chain_backend_apply,
789
- /* .backend_set_input = */ llama_sampler_chain_backend_set_input,
790
- };
791
-
792
- struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
793
- return llama_sampler_init(
794
- /* .iface = */ &llama_sampler_chain_i,
795
- /* .ctx = */ new llama_sampler_chain {
796
- /* .params = */ params,
797
- /* .is_init = */ false,
798
- /* .samplers = */ {},
799
- /* .cur = */ {},
800
- /* .t_sample_us = */ 0,
801
- /* .n_sample = */ 0,
802
- }
803
- );
804
- }
805
-
806
- llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
807
- const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx);
808
- const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
809
- const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
810
- const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
811
-
812
- // If a backend sampler has already sampled a token, return it.
813
- if (sampled_token != LLAMA_TOKEN_NULL) {
814
- LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
815
- return sampled_token;
816
- }
817
-
818
- const llama_model * model = llama_get_model(ctx);
819
- const llama_vocab * vocab = llama_model_get_vocab(model);
820
-
821
- const int n_vocab = llama_vocab_n_tokens(vocab);
822
-
823
- // use pre-allocated buffer from chain if available, otherwise allocate locally
824
- std::vector<llama_token_data> * cur_ptr;
825
- std::vector<llama_token_data> cur_local;
826
-
827
- if (smpl->iface == &llama_sampler_chain_i) {
828
- auto * chain = (llama_sampler_chain *) smpl->ctx;
829
- cur_ptr = &chain->cur;
830
- } else {
831
- cur_ptr = &cur_local;
832
- }
833
-
834
- auto & cur = *cur_ptr;
835
-
836
- if (sampled_probs) {
837
- const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
838
- cur.resize(sampled_probs_count);
839
- for (uint32_t i = 0; i < sampled_probs_count; ++i) {
840
- cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
841
- }
842
- } else if (sampled_logits) {
843
- const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
844
- cur.resize(sampled_logits_count);
845
- for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
846
- cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
847
- }
848
- } else {
849
- const auto * logits = llama_get_logits_ith(ctx, idx);
850
- GGML_ASSERT(logits != nullptr);
851
- cur.resize(n_vocab);
852
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
853
- cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
854
- }
855
- }
856
-
857
- llama_token_data_array cur_p = {
858
- /* .data = */ cur.data(),
859
- /* .size = */ cur.size(),
860
- /* .selected = */ -1,
861
- /* .sorted = */ false,
862
- };
863
-
864
- llama_sampler_apply(smpl, &cur_p);
865
-
866
- GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
867
-
868
- auto token = cur_p.data[cur_p.selected].id;
869
-
870
- llama_sampler_accept(smpl, token);
871
-
872
- return token;
873
- }
874
-
875
-
876
- void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
877
- auto * p = (llama_sampler_chain *) chain->ctx;
878
- p->samplers.push_back({
879
- /* .is_backend = */ false,
880
- /* .ptr = */ smpl,
881
- });
882
- }
883
-
884
- struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
885
- if (chain == nullptr) {
886
- return nullptr;
887
- }
888
-
889
- if (chain->iface != &llama_sampler_chain_i) {
890
- return nullptr;
891
- }
892
-
893
- if (i == -1) {
894
- return chain;
895
- }
896
-
897
- const auto * p = (const llama_sampler_chain *) chain->ctx;
898
-
899
- if (i < 0 || (size_t) i >= p->samplers.size()) {
900
- return nullptr;
901
- }
902
-
903
- return p->samplers[i].ptr;
904
- }
905
-
906
- struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
907
- auto * p = (llama_sampler_chain *) chain->ctx;
908
-
909
- if (i < 0 || (size_t) i >= p->samplers.size()) {
910
- return nullptr;
911
- }
912
-
913
- auto * result = p->samplers[i].ptr;
914
- p->samplers.erase(p->samplers.begin() + i);
915
-
916
- return result;
917
- }
918
-
919
- int llama_sampler_chain_n(const struct llama_sampler * chain) {
920
- const auto * p = (const llama_sampler_chain *) chain->ctx;
921
-
922
- return p->samplers.size();
923
- }
924
-
925
- //
926
- // samplers
927
- //
928
-
929
- // greedy
930
-
// Greedy sampler state. It adds no fields of its own. The name (get_name()) and the
// backend-init state (init()) come from the llama_sampler_backend base.
struct llama_sampler_greedy : public llama_sampler_backend {
};
933
-
934
- static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
935
- auto * sctx = (llama_sampler_greedy *) smpl->ctx;
936
- return sctx->get_name();
937
- }
938
-
939
- static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
940
- auto * ctx = (llama_sampler_greedy *) smpl->ctx;
941
- GGML_UNUSED(ctx);
942
- }
943
-
944
- static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
945
- const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
946
- auto * result = llama_sampler_init_greedy();
947
-
948
- // copy the state
949
- {
950
- auto * result_ctx = (llama_sampler_greedy *) result->ctx;
951
-
952
- GGML_UNUSED(ctx);
953
- GGML_UNUSED(result_ctx);
954
- }
955
-
956
- return result;
957
- }
958
-
959
- static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
960
- delete (llama_sampler_greedy *) smpl->ctx;
961
- }
962
-
963
- static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
964
- cur_p->selected = 0;
965
- for (size_t i = 1; i < cur_p->size; ++i) {
966
- if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
967
- cur_p->selected = i;
968
- }
969
- }
970
- }
971
-
972
- static bool llama_sampler_greedy_backend_init(
973
- struct llama_sampler * smpl,
974
- ggml_backend_buffer_type_t buft) {
975
- auto * sctx = (llama_sampler_greedy *) smpl->ctx;
976
-
977
- const bool res = llama_sampler_backend_support(smpl, buft);
978
-
979
- sctx->init(res);
980
-
981
- return res;
982
- }
983
-
984
- static void llama_sampler_greedy_backend_apply(
985
- struct llama_sampler * smpl,
986
- struct ggml_context * ctx,
987
- struct ggml_cgraph * gf,
988
- struct llama_sampler_data * data) {
989
- GGML_UNUSED(gf);
990
- GGML_UNUSED(smpl);
991
-
992
- struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
993
- ggml_set_name(curl, "greedy_argmax");
994
-
995
- data->sampled = curl;
996
- }
997
-
// Dispatch table for the greedy sampler. Entries follow the llama_sampler_i member order.
// Greedy has no accept or backend_accept step and needs no backend input tensors, so
// those hooks are nullptr.
static struct llama_sampler_i llama_sampler_greedy_i = {
    /* .name              = */ llama_sampler_greedy_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_greedy_apply,
    /* .reset             = */ llama_sampler_greedy_reset,
    /* .clone             = */ llama_sampler_greedy_clone,
    /* .free              = */ llama_sampler_greedy_free,
    /* .backend_init      = */ llama_sampler_greedy_backend_init,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ llama_sampler_greedy_backend_apply,
    /* .backend_set_input = */ nullptr,
};
1010
-
1011
- struct llama_sampler * llama_sampler_init_greedy() {
1012
- return llama_sampler_init(
1013
- /* .iface = */ &llama_sampler_greedy_i,
1014
- /* .ctx = */ new llama_sampler_greedy {
1015
- ("greedy"),
1016
- }
1017
- );
1018
- }
1019
-
1020
- // dist
1021
-
// State of the "dist" sampler, which draws one candidate at random from the softmax distribution.
struct llama_sampler_dist : public llama_sampler_backend {
    const uint32_t seed;     // seed as passed by the caller; reused by clone and reset
    uint32_t       seed_cur; // seed actually in use, as returned by get_rng_seed(seed)

    std::mt19937 rng;        // shared by the CPU path (apply) and the backend path (set_input)

    // backend input
    struct ggml_tensor * inp_uniform; // 1-element F32 tensor holding the uniform draw; created in inp_ctx

    ggml_context_ptr inp_ctx;         // context that owns inp_uniform's tensor metadata
    ggml_backend_buffer_ptr inp_buf;  // backend buffer allocated for the tensors of inp_ctx
};
1034
-
1035
- static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
1036
- auto * sctx = (llama_sampler_dist *) smpl->ctx;
1037
- return sctx->get_name();
1038
- }
1039
-
// CPU path of the "dist" sampler.
//
// Computes a softmax over the candidate logits, shifted by the max logit for numerical
// stability, then draws one candidate with probability proportional to its weight.
// On return, cur_p->data[i].p holds the normalised probabilities and cur_p->selected
// holds the drawn index (-1 when there are no candidates).
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_dist *) smpl->ctx;

    // edge cases
    if (cur_p->size == 0) {
        cur_p->selected = -1;
        return;
    }

    cur_p->selected = 0;

    // a single candidate is selected with certainty; no random draw is consumed
    if (cur_p->size == 1) {
        cur_p->data[0].p = 1.0f;
        return;
    }

    // max logit for numerical stability
    // (for a sorted array data[0] is taken as the max — assumes descending order)
    float max_l = cur_p->data[0].logit;
    if (!cur_p->sorted) {
        for (size_t i = 1; i < cur_p->size; ++i) {
            max_l = std::max(max_l, cur_p->data[i].logit);
        }
    }

    // apply softmax to obtain the probabilities
    // (unnormalised here; sum_cum is the normalising constant, accumulated in double)
    double sum_cum = 0.0f;
    for (size_t i = 0; i < cur_p->size; ++i) {
        float p = expf(cur_p->data[i].logit - max_l);
        cur_p->data[i].p = p;
        sum_cum += p;
    }

#if 1
    // sample from the obtained probabilities and normalize the probs in a single pass
    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
    //
    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
    const double rnd = dist(ctx->rng);

    // target is a point in [0, sum_cum) on the unnormalised scale, so no
    // normalisation is needed before the search
        double sum_run = 0.0f;
    const double sum_tgt = sum_cum*rnd;

    bool found = false;
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (!found) {
            // accumulate probs until we reach the target sum
            sum_run += cur_p->data[i].p;
            if (sum_run >= sum_tgt) {
                cur_p->selected = i;
                found = true;
            }
        }

        // normalize probs
        cur_p->data[i].p /= sum_cum;
    }

    // fallback to the last token (don't think this can happen)
    assert(found);
    if (!found) {
        cur_p->selected = cur_p->size - 1;
    }
#else
    // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].p /= sum_cum;
    }

    cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
#endif
}
1111
-
1112
- static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
1113
- auto * ctx = (llama_sampler_dist *) smpl->ctx;
1114
- ctx->seed_cur = get_rng_seed(ctx->seed);
1115
- ctx->rng.seed(ctx->seed_cur);
1116
- }
1117
-
1118
- static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
1119
- const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
1120
- auto * result = llama_sampler_init_dist(ctx->seed);
1121
-
1122
- // copy the state
1123
- {
1124
- auto * result_ctx = (llama_sampler_dist *) result->ctx;
1125
-
1126
- result_ctx->rng = ctx->rng;
1127
- }
1128
-
1129
- return result;
1130
- }
1131
-
1132
- static void llama_sampler_dist_free(struct llama_sampler * smpl) {
1133
- delete (llama_sampler_dist *) smpl->ctx;
1134
- }
1135
-
1136
- static bool llama_sampler_dist_backend_init(
1137
- struct llama_sampler * smpl,
1138
- ggml_backend_buffer_type_t buft) {
1139
- auto * sctx = (llama_sampler_dist *) smpl->ctx;
1140
-
1141
- // allocate inputs
1142
- {
1143
- ggml_init_params params = {
1144
- /*.mem_size =*/ ggml_tensor_overhead(),
1145
- /*.mem_buffer =*/ nullptr,
1146
- /*.no_alloc =*/ true,
1147
- };
1148
-
1149
- sctx->inp_ctx.reset(ggml_init(params));
1150
-
1151
- // Create the uniform random scalar input tensor. This will be set by
1152
- // llama_sampler_dist_backend_set_input after this graph is built.
1153
- sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
1154
- ggml_set_name (sctx->inp_uniform, "uniform");
1155
- ggml_set_input(sctx->inp_uniform);
1156
-
1157
- // Allocate all tensors from our context to the backend
1158
- sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
1159
-
1160
- ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
1161
- }
1162
-
1163
- const bool res = llama_sampler_backend_support(smpl, buft);
1164
-
1165
- sctx->init(res);
1166
-
1167
- if (!res) {
1168
- sctx->inp_ctx.reset(nullptr);
1169
- sctx->inp_buf.reset(nullptr);
1170
- }
1171
-
1172
- return res;
1173
- }
1174
-
1175
- static void llama_sampler_dist_backend_apply(
1176
- struct llama_sampler * smpl,
1177
- struct ggml_context * ctx,
1178
- struct ggml_cgraph * gf,
1179
- struct llama_sampler_data * data) {
1180
- GGML_UNUSED(gf);
1181
- auto * sctx = (llama_sampler_dist *) smpl->ctx;
1182
-
1183
- struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
1184
- ggml_set_name(probs, "dist_probs");
1185
-
1186
- struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
1187
- ggml_set_name(cumsum, "dist_cumsum");
1188
-
1189
- // The uniform tensor has a random value and we subtract this tensor with
1190
- // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
1191
- // Recall that each entry in cumsum is the cumulative probability up to that
1192
- // index so values stay negative while the cumulative total is below the
1193
- // random value, and become zero/positive once the threshold is crossed.
1194
- struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
1195
- ggml_set_name(diff, "dist_cumsum");
1196
-
1197
- // The ggml_step function produces a tensor where entries are 1 if the
1198
- // corresponding entry in diff is > 0, and 0 otherwise. So all values up to
1199
- // the index where the cumulative probability exceeds the random value are 0,
1200
- // and all entries after that are 1.
1201
- struct ggml_tensor * mask = ggml_step(ctx, diff);
1202
- ggml_set_name(mask, "dist_mask");
1203
-
1204
- // Taking the sum of the mask gives us the sum of elements after the threshold
1205
- // we are interested in.
1206
- struct ggml_tensor * idxf = ggml_sum(ctx, mask);
1207
- ggml_set_name(idxf, "dist_index_f32");
1208
-
1209
- // Use ggml_scale_bias to scale the index value by -1 and then add the size
1210
- // of the mask to that value so we get the correct index ((-1 * idxf) + n).
1211
- struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
1212
- ggml_set_name(idx, "dist_index_i32");
1213
-
1214
- // Map back to original vocab ids if a candidates tensor is available.
1215
- struct ggml_tensor * sampled_token = idx;
1216
- if (data->candidates != nullptr) {
1217
- struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
1218
-
1219
- sampled_token = ggml_get_rows(ctx, candidates, idx);
1220
- ggml_set_name(sampled_token, "dist_sampled_token");
1221
- }
1222
-
1223
- data->sampled = sampled_token;
1224
- data->probs = probs;
1225
- }
1226
-
1227
- static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
1228
- auto * sctx = (llama_sampler_dist *) smpl->ctx;
1229
- GGML_ASSERT(sctx->inp_uniform != nullptr);
1230
-
1231
- // We sample in double precision and cast to float to match rnd numbers of
1232
- // llama_dampler_dist which uses double precision (sampling from
1233
- // std::uniform_real_distribution<double> and
1234
- // std::uniform_real_distribution<float> with same rng will produce
1235
- // different sequences).
1236
- std::uniform_real_distribution<double> dist(0.0f, 1.0f);
1237
- const float rnd = dist(sctx->rng);
1238
-
1239
- ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
1240
- }
1241
-
// Dispatch table for the dist sampler. Entries follow the llama_sampler_i member order.
// It is the only sampler here with a backend_set_input hook, which feeds the uniform
// random value to the graph.
static struct llama_sampler_i llama_sampler_dist_i = {
    /* .name              = */ llama_sampler_dist_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_dist_apply,
    /* .reset             = */ llama_sampler_dist_reset,
    /* .clone             = */ llama_sampler_dist_clone,
    /* .free              = */ llama_sampler_dist_free,
    /* .backend_init      = */ llama_sampler_dist_backend_init,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ llama_sampler_dist_backend_apply,
    /* .backend_set_input = */ llama_sampler_dist_backend_set_input,
};
1254
-
1255
- struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
1256
- auto seed_cur = get_rng_seed(seed);
1257
- return llama_sampler_init(
1258
- /* .iface = */ &llama_sampler_dist_i,
1259
- /* .ctx = */ new llama_sampler_dist {
1260
- ("dist"),
1261
- /* .seed = */ seed,
1262
- /* .seed_cur = */ seed_cur,
1263
- /* .rng = */ std::mt19937(seed_cur),
1264
- /* .inp_uniform = */ nullptr,
1265
- /* .inp_ctx = */ nullptr,
1266
- /* .inp_buf = */ nullptr,
1267
- }
1268
- );
1269
- }
1270
-
1271
- // top-k
1272
-
// Top-k sampler state: keep only the k highest-scoring candidates.
struct llama_sampler_top_k : public llama_sampler_backend {
    const int32_t k; // always > 0 here; llama_sampler_init_top_k returns an empty sampler for k <= 0
};
1276
-
1277
- static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
1278
- auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1279
- return sctx->get_name();
1280
- }
1281
-
1282
- static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1283
- auto * ctx = (llama_sampler_top_k *) smpl->ctx;
1284
- llama_sampler_top_k_impl(cur_p, ctx->k);
1285
- }
1286
-
1287
- static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
1288
- const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
1289
- return llama_sampler_init_top_k(ctx->k);
1290
- }
1291
-
1292
- static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
1293
- delete (llama_sampler_top_k *) smpl->ctx;
1294
- }
1295
-
1296
- static bool llama_sampler_top_k_backend_init(
1297
- struct llama_sampler * smpl,
1298
- ggml_backend_buffer_type_t buft) {
1299
- auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1300
-
1301
- const bool res = llama_sampler_backend_support(smpl, buft);
1302
-
1303
- sctx->init(res);
1304
-
1305
- return res;
1306
- }
1307
-
1308
- static void llama_sampler_top_k_backend_apply(
1309
- struct llama_sampler * smpl,
1310
- struct ggml_context * ctx,
1311
- struct ggml_cgraph * gf,
1312
- struct llama_sampler_data * data) {
1313
- auto * sctx = (llama_sampler_top_k *) smpl->ctx;
1314
-
1315
- struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
1316
- ggml_set_name(top_k, "top_k");
1317
-
1318
- if (data->candidates) {
1319
- struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1320
- data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
1321
- data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
1322
- ggml_set_name(data->candidates, "top_k_candidates");
1323
- } else {
1324
- data->candidates = top_k;
1325
- }
1326
-
1327
- struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1328
- struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
1329
- data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
1330
- ggml_set_name(top_k_rows, "top_k_rows");
1331
-
1332
- GGML_UNUSED(gf);
1333
- }
1334
-
// Dispatch table for the top-k sampler. Entries follow the llama_sampler_i member order.
// Top-k is stateless, so accept and reset are nullptr, and it needs no backend inputs.
static struct llama_sampler_i llama_sampler_top_k_i = {
    /* .name              = */ llama_sampler_top_k_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_top_k_apply,
    /* .reset             = */ nullptr,
    /* .clone             = */ llama_sampler_top_k_clone,
    /* .free              = */ llama_sampler_top_k_free,
    /* .backend_init      = */ llama_sampler_top_k_backend_init,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ llama_sampler_top_k_backend_apply,
    /* .backend_set_input = */ nullptr,
};
1347
-
1348
- struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
1349
- const bool is_empty = (k <= 0);
1350
-
1351
- if (is_empty) {
1352
- return llama_sampler_init_empty("?top-k");
1353
- }
1354
-
1355
- return llama_sampler_init(
1356
- /* .iface = */ &llama_sampler_top_k_i,
1357
- /* .ctx = */ new llama_sampler_top_k {
1358
- ("top-k"),
1359
- /* .k = */ k,
1360
- }
1361
- );
1362
- }
1363
-
1364
- // top-p
1365
-
// Top-p (nucleus) sampler state.
struct llama_sampler_top_p : public llama_sampler_backend {
    const float p;         // cumulative-probability threshold; < 1 here (init returns an empty sampler for p >= 1)
    const size_t min_keep; // the CPU path never keeps fewer than this many candidates

    std::vector<llama_token_data> buf_sort; // scratch buffer reused by the CPU path's partial sorts
};
1372
-
1373
- static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
1374
- auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1375
- return sctx->get_name();
1376
- }
1377
-
// CPU path of the top-p (nucleus) sampler.
//
// Keeps the shortest prefix of candidates, in descending-probability order, whose
// cumulative probability reaches p, but never fewer than min_keep candidates.
// Unsorted arrays are not fully sorted when that can be avoided. Large arrays
// (> 1024 entries) start with a partial sort of the top 256 into buf_sort, and fall
// back to a full sort only when the nucleus turns out to be larger than that.
// On return cur_p->size is the number of kept candidates.
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_top_p *) smpl->ctx;

    // p >= 1 keeps every candidate
    if (ctx->p >= 1.0f) {
        return;
    }

    // fill in .p for every candidate (the `false` presumably disables sorting — sorting is handled below)
    llama_sampler_softmax_impl(cur_p, false);

    size_t k = cur_p->size;
    auto * pdata = cur_p->data;

    auto & buf_sort = ctx->buf_sort;

    // if not sorted, try adaptive top-k sorting
    if (!cur_p->sorted && cur_p->size > 1024) {
        k = std::min<size_t>(256, cur_p->size);
        llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
        pdata = buf_sort.data();
    } else if (!cur_p->sorted) {
        // small candidates -> sort inplace
        // NOTE(review): the code below relies on this marking cur_p as sorted; otherwise the
        // final copy would read buf_sort — confirm in llama_token_data_array_partial_sort_inplace
        llama_token_data_array_partial_sort_inplace(cur_p, k);
    }

    // Compute the cumulative probabilities
    float cum_sum = 0.0f;
    size_t last_idx = cur_p->size;

    for (size_t i = 0; i < cur_p->size; ++i) {
        cum_sum += pdata[i].p;

        // Check if the running sum is at least p or if we have kept at least min_keep tokens
        // we set the last index to i+1 to indicate that the current iterate should be included in the set
        if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
            last_idx = i + 1;
            break;
        }

        // we exceeded the current top-k heuristic -> increase k and continue
        // (re-sort everything into buf_sort; its first k entries keep the same order)
        if (!cur_p->sorted && i == k - 1) {
            k = cur_p->size;
            llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
            pdata = buf_sort.data();
        }
    }

    // Resize the output vector to keep only the top-p tokens
    // (in the partial-sort path the kept, sorted prefix lives in buf_sort and is copied back)
    if (!cur_p->sorted) {
        std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data);
        cur_p->sorted = true;
    }

    cur_p->size = last_idx;
}
1432
-
1433
- static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
1434
- const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
1435
- return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
1436
- }
1437
-
1438
- static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
1439
- delete (llama_sampler_top_p *) smpl->ctx;
1440
- }
1441
-
1442
- static bool llama_sampler_top_p_backend_init(
1443
- struct llama_sampler * smpl,
1444
- ggml_backend_buffer_type_t buft) {
1445
- auto * sctx = (llama_sampler_top_p *) smpl->ctx;
1446
-
1447
- const bool res = llama_sampler_backend_support(smpl, buft);
1448
-
1449
- sctx->init(res);
1450
-
1451
- return res;
1452
- }
1453
-
// Backend path of top-p (nucleus) filtering, built entirely from graph ops.
//
// The logits are sorted in descending order, and the candidates are reordered to match
// (or become the sort indices). Logits outside the nucleus receive a -1e9 bias instead of
// being removed, so data->logits keeps its length but is now in sorted order.
// NOTE(review): unlike the CPU path, min_keep is not applied here.
static void llama_sampler_top_p_backend_apply(
        struct llama_sampler      * smpl,
        struct ggml_context       * ctx,
        struct ggml_cgraph        * gf,
        struct llama_sampler_data * data) {
    auto * sctx = (llama_sampler_top_p *) smpl->ctx;

    // Gather the entries of 1-D tensor `a` in the order given by index tensor `b`.
    auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
        GGML_ASSERT(ggml_nrows(a) == 1);
        struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
        struct ggml_tensor * a_sorted   = ggml_get_rows(ctx, a_reshaped, b);
        return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
    };

    // Get the sorted logits in descending order.
    struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
    ggml_set_name(sorted_idx, "top_p_sorted_idx");

    // Do the sorting via reshape + get_rows
    struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
    ggml_set_name(sorted_logits, "top_p_sorted_logits");

    struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
    ggml_set_name(softmax, "top_p_softmax");

    // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
    if (data->candidates) {
        data->candidates = ggml_sort(data->candidates, sorted_idx);
    } else {
        data->candidates = sorted_idx;
    }
    ggml_set_name(data->candidates, "top_p_candidates");

    // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
    struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
    ggml_set_name(cdf, "top_p_cdf");

    // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
    // (p - cdf > 0 exactly while the running total is still below p)
    struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
    ggml_set_name(cdf_scaled, "top_p_cdf_scaled");

    struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
    ggml_set_name(mask, "top_p_mask");

    // Taking the sum of the mask gives us the sum of elements after the threshold
    // we are interested in.
    // (i.e. the count of leading entries with cdf < p, which is also the index of the
    // first entry that reaches p)
    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
    ggml_set_name(idxf, "top_p_index_f32");

    // prevent out-of-bounds access
    idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);

    // construct ones tensor to set the value in the mask
    // (0 * idxf + 1 == 1, shaped like idxf)
    struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
    ggml_set_name(ones, "top_p_ones");

    // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
    // by also marking the entry at idxf, the one that crosses p, as kept
    struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);

    mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
    mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);

    // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
    // top_p_bias = (mask * 1e9f) - 1e9f.
    // So entries in the mask that we want to discard will become -1e9f, and
    // others will be 0 (meaning that will not effect the logits).
    const float large_val = 1e9f;
    struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
    ggml_set_name(top_p_bias, "top_p_bias");

    data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
    ggml_set_name(data->logits, "top_p_logits");

    GGML_UNUSED(gf);
}
1529
-
// Dispatch table for the top-p sampler. Entries follow the llama_sampler_i member order.
// Top-p keeps no per-sequence state, so accept and reset are nullptr, and it needs no
// backend inputs.
static struct llama_sampler_i llama_sampler_top_p_i = {
    /* .name              = */ llama_sampler_top_p_name,
    /* .accept            = */ nullptr,
    /* .apply             = */ llama_sampler_top_p_apply,
    /* .reset             = */ nullptr,
    /* .clone             = */ llama_sampler_top_p_clone,
    /* .free              = */ llama_sampler_top_p_free,
    /* .backend_init      = */ llama_sampler_top_p_backend_init,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ llama_sampler_top_p_backend_apply,
    /* .backend_set_input = */ nullptr,
};
1542
-
1543
- struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
1544
- const bool is_empty = p >= 1.0f;
1545
-
1546
- if (is_empty) {
1547
- return llama_sampler_init_empty("?top-p");
1548
- }
1549
-
1550
- return llama_sampler_init(
1551
- /* .iface = */ &llama_sampler_top_p_i,
1552
- /* .ctx = */ new llama_sampler_top_p {
1553
- ("top-p"),
1554
- /* .p = */ p,
1555
- /* .min_keep = */ min_keep,
1556
- /* .buf_sort = */ {},
1557
- }
1558
- );
1559
- }
1560
-
1561
- // min-p
1562
-
// Min-p sampler state: keep candidates whose probability is at least p times the top probability.
struct llama_sampler_min_p : public llama_sampler_backend {
    const float p;         // relative threshold; > 0 here (init returns an empty sampler for p <= 0)
    const size_t min_keep; // the CPU path never keeps fewer than this many candidates
};
1567
-
1568
- static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
1569
- auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1570
- return sctx->get_name();
1571
- }
1572
-
1573
- static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1574
- auto * ctx = (llama_sampler_min_p *) smpl->ctx;
1575
-
1576
- if (ctx->p <= 0.0f || !cur_p->size) {
1577
- return;
1578
- }
1579
-
1580
- bool min_p_applied = false;
1581
-
1582
- // if the cur_p aren't sorted, try the unsorted implementation first
1583
- if (!cur_p->sorted) {
1584
- std::vector<llama_token_data> filtered_tokens;
1585
-
1586
- float max_logit = -FLT_MAX;
1587
- for (size_t i = 0; i < cur_p->size; ++i) {
1588
- max_logit = std::max(max_logit, cur_p->data[i].logit);
1589
- }
1590
- const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
1591
-
1592
- for (size_t i = 0; i < cur_p->size; ++i) {
1593
- if (cur_p->data[i].logit >= min_logit) {
1594
- filtered_tokens.push_back(cur_p->data[i]);
1595
- }
1596
- }
1597
-
1598
- // if we have enough values the operation was a success
1599
- if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
1600
- std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
1601
- cur_p->size = filtered_tokens.size();
1602
- min_p_applied = true;
1603
- }
1604
- }
1605
-
1606
- // if the cur_p are sorted or the unsorted implementation failed, use this implementation
1607
- if (!min_p_applied) {
1608
- // Sort the logits in descending order
1609
- if (!cur_p->sorted) {
1610
- llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
1611
- }
1612
-
1613
- const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
1614
- size_t i = 1; // first token always matches
1615
-
1616
- for (; i < cur_p->size; ++i) {
1617
- if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
1618
- break; // prob too small
1619
- }
1620
- }
1621
-
1622
- // Resize the output vector to keep only the matching tokens
1623
- cur_p->size = i;
1624
- }
1625
- }
1626
-
1627
- static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
1628
- const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
1629
- return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
1630
- }
1631
-
1632
- static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
1633
- delete (llama_sampler_min_p *) smpl->ctx;
1634
- }
1635
-
1636
- static bool llama_sampler_min_p_backend_init(
1637
- struct llama_sampler * smpl,
1638
- ggml_backend_buffer_type_t buft) {
1639
- auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1640
-
1641
- const bool res = llama_sampler_backend_support(smpl, buft);
1642
-
1643
- sctx->init(res);
1644
-
1645
- return res;
1646
- }
1647
-
1648
- static void llama_sampler_min_p_backend_apply(
1649
- struct llama_sampler * smpl,
1650
- struct ggml_context * ctx,
1651
- struct ggml_cgraph * gf,
1652
- struct llama_sampler_data * data) {
1653
- auto * sctx = (llama_sampler_min_p *) smpl->ctx;
1654
-
1655
- struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1656
- ggml_set_name(max_idx, "max_idx");
1657
-
1658
- struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1659
- ggml_set_name(logits_rows, "logits_rows");
1660
-
1661
- struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
1662
- ggml_set_name(max_logit, "max_logit");
1663
-
1664
- // Calculate the threshold value.
1665
- struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
1666
- ggml_set_name(threshold, "min_p_threshold");
1667
-
1668
- // Subtract the threshold from logits.
1669
- struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
1670
-
1671
- // Create a mask where logits below the threshold are 0 (discard),
1672
- // and others are 1 (keep).
1673
- struct ggml_tensor * mask = ggml_step(ctx, sub);
1674
- ggml_set_name(mask, "min_p_mask");
1675
-
1676
- // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
1677
- // min_p_bias = (mask * 1e9f) - 1e9f.
1678
- // So entries in the mask that we want to discard will become -1e9f, and
1679
- // others will be 0 (meaning that will not effect the logits).
1680
- const float large_val = 1e9f;
1681
- struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
1682
- ggml_set_name(min_p_bias, "min_p_bias");
1683
-
1684
- // Add the min_p bias to the logits.
1685
- data->logits = ggml_add(ctx, data->logits, min_p_bias);
1686
- ggml_set_name(data->logits, "min_p_logits");
1687
-
1688
- GGML_UNUSED(gf);
1689
- }
1690
-
1691
- static struct llama_sampler_i llama_sampler_min_p_i = {
1692
- /* .name = */ llama_sampler_min_p_name,
1693
- /* .accept = */ nullptr,
1694
- /* .apply = */ llama_sampler_min_p_apply,
1695
- /* .reset = */ nullptr,
1696
- /* .clone = */ llama_sampler_min_p_clone,
1697
- /* .free = */ llama_sampler_min_p_free,
1698
- /* .backend_init = */ llama_sampler_min_p_backend_init,
1699
- /* .backend_accept = */ nullptr,
1700
- /* .backend_apply = */ llama_sampler_min_p_backend_apply,
1701
- /* .backend_set_input = */ nullptr,
1702
- };
1703
-
1704
- struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
1705
- const bool is_empty = (p <= 0.0f);
1706
-
1707
- if (is_empty) {
1708
- return llama_sampler_init_empty("?min-p");
1709
- }
1710
-
1711
- return llama_sampler_init(
1712
- /* .iface = */ &llama_sampler_min_p_i,
1713
- /* .ctx = */ new llama_sampler_min_p {
1714
- ("min-p"),
1715
- /* .p = */ p,
1716
- /* .min_keep = */ min_keep,
1717
- }
1718
- );
1719
- }
1720
-
1721
- // typical
1722
-
1723
- struct llama_sampler_typical {
1724
- const float p;
1725
- const size_t min_keep;
1726
- };
1727
-
1728
- static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
1729
- return "typical";
1730
- }
1731
-
1732
- static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1733
- auto * ctx = (llama_sampler_typical *) smpl->ctx;
1734
-
1735
- // Reference implementation:
1736
- // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
1737
- if (ctx->p >= 1.0f) {
1738
- return;
1739
- }
1740
-
1741
- // Compute the softmax of logits and calculate entropy
1742
- llama_sampler_softmax_impl(cur_p, true);
1743
-
1744
- float entropy = 0.0f;
1745
- for (size_t i = 0; i < cur_p->size; ++i) {
1746
- entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
1747
- }
1748
-
1749
- // Compute the absolute difference between negative log probability and entropy for each candidate
1750
- std::vector<float> shifted_scores;
1751
- for (size_t i = 0; i < cur_p->size; ++i) {
1752
- float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
1753
- shifted_scores.push_back(shifted_score);
1754
- }
1755
-
1756
- // Sort tokens based on the shifted_scores and their corresponding indices
1757
- std::vector<size_t> indices(cur_p->size);
1758
- std::iota(indices.begin(), indices.end(), 0);
1759
-
1760
- std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
1761
- return shifted_scores[a] < shifted_scores[b];
1762
- });
1763
-
1764
- // Compute the cumulative probabilities
1765
- float cum_sum = 0.0f;
1766
- size_t last_idx = indices.size();
1767
-
1768
- for (size_t i = 0; i < indices.size(); ++i) {
1769
- size_t idx = indices[i];
1770
- cum_sum += cur_p->data[idx].p;
1771
-
1772
- // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
1773
- if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) {
1774
- last_idx = i + 1;
1775
- break;
1776
- }
1777
- }
1778
-
1779
- // Resize the output vector to keep only the locally typical tokens
1780
- std::vector<llama_token_data> cur_p_new;
1781
- for (size_t i = 0; i < last_idx; ++i) {
1782
- size_t idx = indices[i];
1783
- cur_p_new.push_back(cur_p->data[idx]);
1784
- }
1785
-
1786
- // Replace the data in cur_p with the cur_p_new data
1787
- std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
1788
- cur_p->size = cur_p_new.size();
1789
- cur_p->sorted = false;
1790
- }
1791
-
1792
- static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
1793
- const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
1794
- return llama_sampler_init_typical(ctx->p, ctx->min_keep);
1795
- }
1796
-
1797
- static void llama_sampler_typical_free(struct llama_sampler * smpl) {
1798
- delete (llama_sampler_typical *) smpl->ctx;
1799
- }
1800
-
1801
- static struct llama_sampler_i llama_sampler_typical_i = {
1802
- /* .name = */ llama_sampler_typical_name,
1803
- /* .accept = */ nullptr,
1804
- /* .apply = */ llama_sampler_typical_apply,
1805
- /* .reset = */ nullptr,
1806
- /* .clone = */ llama_sampler_typical_clone,
1807
- /* .free = */ llama_sampler_typical_free,
1808
- /* .backend_init = */ nullptr,
1809
- /* .backend_accept = */ nullptr,
1810
- /* .backend_apply = */ nullptr,
1811
- /* .backend_set_input = */ nullptr,
1812
- };
1813
-
1814
- struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
1815
- const bool is_empty = (p >= 1.0f);
1816
-
1817
- if (is_empty) {
1818
- return llama_sampler_init_empty("?typical");
1819
- }
1820
-
1821
- return llama_sampler_init(
1822
- /* .iface = */ &llama_sampler_typical_i,
1823
- /* .ctx = */ new llama_sampler_typical {
1824
- /* .p = */ p,
1825
- /* .min_keep = */ min_keep,
1826
- }
1827
- );
1828
- }
1829
-
1830
- // temp
1831
-
1832
- struct llama_sampler_temp : public llama_sampler_backend {
1833
- const float temp;
1834
- };
1835
-
1836
- static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
1837
- auto * sctx = (llama_sampler_temp *) smpl->ctx;
1838
- return sctx->get_name();
1839
- }
1840
-
1841
- static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1842
- const auto * ctx = (llama_sampler_temp *) smpl->ctx;
1843
-
1844
- llama_sampler_temp_impl(cur_p, ctx->temp);
1845
- }
1846
-
1847
- static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
1848
- const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
1849
- return llama_sampler_init_temp(ctx->temp);
1850
- }
1851
-
1852
- static void llama_sampler_temp_free(struct llama_sampler * smpl) {
1853
- delete (llama_sampler_temp *) smpl->ctx;
1854
- }
1855
-
1856
- static void llama_sampler_backend_temp_sampling(
1857
- struct ggml_context * ctx,
1858
- struct ggml_cgraph * gf,
1859
- struct llama_sampler_data * data,
1860
- float temp) {
1861
- if (temp <= 0.0f) {
1862
- // Find the most probable token index.
1863
- struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
1864
- ggml_set_name(max_idx, "temp_max_idx");
1865
-
1866
- if (data->candidates) {
1867
- struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
1868
- data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
1869
- } else {
1870
- data->candidates = max_idx;
1871
- }
1872
-
1873
- struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
1874
- data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
1875
-
1876
- return;
1877
- }
1878
-
1879
- data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
1880
-
1881
- GGML_UNUSED(gf);
1882
- }
1883
-
1884
- static bool llama_sampler_temp_backend_init(
1885
- struct llama_sampler * smpl,
1886
- ggml_backend_buffer_type_t buft) {
1887
- auto * sctx = (llama_sampler_temp *) smpl->ctx;
1888
-
1889
- const bool res = llama_sampler_backend_support(smpl, buft);
1890
-
1891
- sctx->init(res);
1892
-
1893
- return res;
1894
- }
1895
-
1896
- static void llama_sampler_temp_backend_apply(
1897
- struct llama_sampler * smpl,
1898
- struct ggml_context * ctx,
1899
- struct ggml_cgraph * gf,
1900
- struct llama_sampler_data * data) {
1901
- auto * sctx = (llama_sampler_temp *) smpl->ctx;
1902
- llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
1903
- }
1904
-
1905
- static struct llama_sampler_i llama_sampler_temp_i = {
1906
- /* .name = */ llama_sampler_temp_name,
1907
- /* .accept = */ nullptr,
1908
- /* .apply = */ llama_sampler_temp_apply,
1909
- /* .reset = */ nullptr,
1910
- /* .clone = */ llama_sampler_temp_clone,
1911
- /* .free = */ llama_sampler_temp_free,
1912
- /* .backend_init = */ llama_sampler_temp_backend_init,
1913
- /* .backend_accept = */ nullptr,
1914
- /* .backend_apply = */ llama_sampler_temp_backend_apply,
1915
- /* .backend_set_input = */ nullptr,
1916
- };
1917
-
1918
- struct llama_sampler * llama_sampler_init_temp(float temp) {
1919
- const bool is_empty = temp == 1.0f;
1920
-
1921
- if (is_empty) {
1922
- return llama_sampler_init_empty("?temp");
1923
- }
1924
-
1925
- return llama_sampler_init(
1926
- /* .iface = */ &llama_sampler_temp_i,
1927
- /* .ctx = */ new llama_sampler_temp {
1928
- ("temp"),
1929
- /*.temp = */ temp,
1930
- }
1931
- );
1932
- }
1933
-
1934
- // temp-ext
1935
-
1936
- struct llama_sampler_temp_ext : public llama_sampler_backend {
1937
- const float temp;
1938
- const float delta;
1939
- const float exponent;
1940
- };
1941
-
1942
- static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
1943
- auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
1944
- return sctx->get_name();
1945
- }
1946
-
1947
- static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1948
- auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
1949
- if (ctx->delta > 0) {
1950
- const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
1951
- const float max_temp = ctx->temp + ctx->delta;
1952
-
1953
- float exponent_val = ctx->exponent;
1954
-
1955
- // no need to do anything if there is only one (or zero) candidates
1956
- if (cur_p->size <= 1) {
1957
- return;
1958
- }
1959
-
1960
- // Calculate maximum possible entropy
1961
- float max_entropy = -logf(1.0f / cur_p->size);
1962
-
1963
- llama_sampler_softmax_impl(cur_p, true);
1964
-
1965
- // Calculate entropy of the softmax probabilities
1966
- float entropy = 0.0f;
1967
- for (size_t i = 0; i < cur_p->size; ++i) {
1968
- float prob = cur_p->data[i].p;
1969
- if (prob > 0.0f) { // Ensure no log(0)
1970
- entropy -= prob * logf(prob);
1971
- }
1972
- }
1973
-
1974
- // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
1975
- float normalized_entropy = entropy / max_entropy;
1976
-
1977
- // Map the normalized entropy to the desired temperature range using the power function
1978
- float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
1979
-
1980
- #ifdef DEBUG
1981
- LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
1982
- LLAMA_LOG_INFO("Entropy: %f\n", entropy);
1983
- LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
1984
- LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
1985
- LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
1986
- LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
1987
- #endif
1988
-
1989
- // Apply the dynamically calculated temperature scaling
1990
- llama_sampler_temp_impl(cur_p, dyn_temp);
1991
-
1992
- // Re-compute softmax probabilities after scaling logits with dynamic temperature
1993
- const double max_l_double = cur_p->data[0].logit;
1994
-
1995
- double cum_sum_double = 0.0;
1996
- for (size_t i = 0; i < cur_p->size; ++i) {
1997
- double p = exp(cur_p->data[i].logit - max_l_double);
1998
- cur_p->data[i].p = p; // Store the scaled probability
1999
- cum_sum_double += p;
2000
- }
2001
-
2002
- for (size_t i = 0; i < cur_p->size; ++i) {
2003
- cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
2004
- }
2005
-
2006
- #ifdef DEBUG
2007
- // Print the updated top 25 probabilities after temperature scaling
2008
- LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
2009
- for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
2010
- LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
2011
- }
2012
- #endif
2013
- } else {
2014
- llama_sampler_temp_impl(cur_p, ctx->temp);
2015
- }
2016
- }
2017
-
2018
- static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
2019
- const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
2020
- return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
2021
- }
2022
-
2023
- static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
2024
- delete (llama_sampler_temp_ext *) smpl->ctx;
2025
- }
2026
-
2027
- static bool llama_sampler_temp_ext_backend_init(
2028
- struct llama_sampler * smpl,
2029
- ggml_backend_buffer_type_t buft) {
2030
- auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
2031
-
2032
- const bool res = llama_sampler_backend_support(smpl, buft);
2033
-
2034
- sctx->init(res);
2035
-
2036
- return res;
2037
- }
2038
-
2039
- static void llama_sampler_temp_ext_backend_apply(
2040
- struct llama_sampler * smpl,
2041
- struct ggml_context * ctx,
2042
- struct ggml_cgraph * gf,
2043
- struct llama_sampler_data * data) {
2044
- auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
2045
-
2046
- // Revert to standard temperature scaling if delta or temp are non-positive.
2047
- if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
2048
- llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
2049
- return;
2050
- }
2051
-
2052
- // Calculate min_temp, max_temp, and max_entropy.
2053
- const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
2054
- const float max_temp = sctx->temp + sctx->delta;
2055
- const float max_entropy = logf(data->logits->ne[0]);
2056
-
2057
- // Calculate the probabilities.
2058
- struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
2059
- ggml_set_name(probs, "temp_ext_softmax_probs");
2060
-
2061
- // Clamp probabilities to avoid log(0) which would give -inf
2062
- struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
2063
- ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
2064
-
2065
- // Calculate the entropy, entropy = -Σ(p * log(p)).
2066
- struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
2067
- struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
2068
- struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
2069
- struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
2070
- ggml_set_name(log_probs, "temp_ext_log_probs");
2071
- ggml_set_name(p_log_p, "temp_ext_p_log_p");
2072
- ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
2073
- ggml_set_name(entropy, "temp_ext_entropy");
2074
-
2075
- // Normalize the entropy, norm_entropy = entropy / max_entropy
2076
- struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
2077
- ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
2078
-
2079
- // Calculate the dynamic temperature:
2080
- // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
2081
- //
2082
- // Calculate powf(normalized_entropy, exponent) as
2083
- // norm_entropy^exponent = exp(exponent * log(norm_entropy))
2084
- struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
2085
- struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
2086
- struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
2087
- // With pow_entropy computed we can now compute dyn_temp, scaling by
2088
- // (max_temp - min_temp) and then adding min_temp.
2089
- struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
2090
- ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
2091
- ggml_set_name(scaled_log, "temp_ext_scaled_log");
2092
- ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
2093
- ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
2094
-
2095
- // Scale the logits by the dynamic temperature
2096
- struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
2097
- ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
2098
-
2099
- data->logits = scaled_logits;
2100
- }
2101
-
2102
- static struct llama_sampler_i llama_sampler_temp_ext_i = {
2103
- /* .name = */ llama_sampler_temp_ext_name,
2104
- /* .accept = */ nullptr,
2105
- /* .apply = */ llama_sampler_temp_ext_apply,
2106
- /* .reset = */ nullptr,
2107
- /* .clone = */ llama_sampler_temp_ext_clone,
2108
- /* .free = */ llama_sampler_temp_ext_free,
2109
- /* .backend_init = */ llama_sampler_temp_ext_backend_init,
2110
- /* .backend_accept = */ nullptr,
2111
- /* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
2112
- /* .backend_set_input = */ nullptr,
2113
- };
2114
-
2115
- struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
2116
- const bool is_empty = temp == 1.0f && delta <= 0.0f;
2117
-
2118
- if (is_empty) {
2119
- return llama_sampler_init_empty("?temp-ext");
2120
- }
2121
-
2122
- auto * res = llama_sampler_init(
2123
- /* .iface = */ &llama_sampler_temp_ext_i,
2124
- /* .ctx = */ new llama_sampler_temp_ext {
2125
- ("temp-ext"),
2126
- /* .temp = */ temp,
2127
- /* .delta = */ delta,
2128
- /* .exponent = */ exponent,
2129
- }
2130
- );
2131
-
2132
- return res;
2133
- }
2134
-
2135
- // xtc
2136
-
2137
// State of the XTC ("exclude top choices") sampler.
struct llama_sampler_xtc {
    const float  probability; // chance that XTC is applied on a given call
    const float  threshold;   // candidates with p >= threshold count as "top choices"
    const size_t min_keep;    // minimum number of candidates that must survive

    const uint32_t seed;      // seed as supplied by the caller
    uint32_t       seed_cur;  // seed actually in use (resolved through get_rng_seed)

    std::mt19937 rng;
};

static const char * llama_sampler_xtc_name(const struct llama_sampler * smpl) {
    (void) smpl;
    return "xtc";
}
2151
-
2152
- static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2153
- auto * ctx = (llama_sampler_xtc *) smpl->ctx;
2154
-
2155
- if (ctx->probability <= 0.0f
2156
- || ctx->threshold > 0.5f
2157
- || cur_p->size < 2) {
2158
- return;
2159
- }
2160
-
2161
- std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
2162
- float chance = distribution(ctx->rng);
2163
- if (chance > ctx->probability) {
2164
- return;
2165
- }
2166
-
2167
- llama_sampler_softmax_impl(cur_p, true);
2168
-
2169
- int pos_last = 0;
2170
-
2171
- for (size_t i = 0; i < cur_p->size; ++i) {
2172
- if (cur_p->data[i].p >= ctx->threshold) {
2173
- pos_last = i;
2174
- } else {
2175
- break;
2176
- }
2177
- }
2178
-
2179
- if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
2180
- cur_p->data += pos_last;
2181
- cur_p->size -= pos_last;
2182
- }
2183
- }
2184
-
2185
- static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
2186
- const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
2187
- auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
2188
-
2189
- // copy the state
2190
- {
2191
- auto * result_ctx = (llama_sampler_xtc *) result->ctx;
2192
-
2193
- result_ctx->rng = ctx->rng;
2194
- }
2195
-
2196
- return result;
2197
- }
2198
-
2199
- static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
2200
- delete (llama_sampler_xtc *) smpl->ctx;
2201
- }
2202
-
2203
- static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
2204
- auto * ctx = (llama_sampler_xtc *) smpl->ctx;
2205
- ctx->seed_cur = get_rng_seed(ctx->seed);
2206
- ctx->rng.seed(ctx->seed_cur);
2207
- }
2208
-
2209
- static struct llama_sampler_i llama_sampler_xtc_i = {
2210
- /* .name = */ llama_sampler_xtc_name,
2211
- /* .accept = */ nullptr,
2212
- /* .apply = */ llama_sample_xtc_apply,
2213
- /* .reset = */ llama_sampler_xtc_reset,
2214
- /* .clone = */ llama_sampler_xtc_clone,
2215
- /* .free = */ llama_sampler_xtc_free,
2216
- /* .backend_init = */ nullptr,
2217
- /* .backend_accept = */ nullptr,
2218
- /* .backend_apply = */ nullptr,
2219
- /* .backend_set_input = */ nullptr,
2220
- };
2221
-
2222
- struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
2223
- const bool is_empty = (p <= 0.0f || t > 0.5f);
2224
-
2225
- if (is_empty) {
2226
- return llama_sampler_init_empty("?xtc");
2227
- }
2228
-
2229
- const auto seed_cur = get_rng_seed(seed);
2230
-
2231
- return llama_sampler_init(
2232
- /* .iface = */ &llama_sampler_xtc_i,
2233
- /* .ctx = */ new llama_sampler_xtc {
2234
- /* .probability = */ p,
2235
- /* .threshold = */ t,
2236
- /* .min_keep = */ min_keep,
2237
- /* .seed = */ seed,
2238
- /* .seed_cur = */ seed_cur,
2239
- /* .rng = */ std::mt19937(seed_cur),
2240
- }
2241
- );
2242
- }
2243
-
2244
- // mirostat
2245
-
2246
// State of the Mirostat (v1) sampler.
struct llama_sampler_mirostat {
    const int32_t n_vocab;   // vocabulary size used in the top-k estimate

    const uint32_t seed;     // seed as supplied by the caller
    uint32_t       seed_cur; // seed actually in use (resolved through get_rng_seed)

    const float tau;         // target surprise
    const float eta;         // learning rate for the mu update

    const int32_t m;         // number of most probable tokens used to estimate s_hat

    float mu;                // current surprise bound; starts at 2 * tau

    std::mt19937 rng;
};

static const char * llama_sampler_mirostat_name(const struct llama_sampler * smpl) {
    (void) smpl;
    return "mirostat";
}
2265
-
2266
- static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2267
- auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
2268
-
2269
- llama_sampler_softmax_impl(cur_p, true);
2270
-
2271
- // Estimate s_hat using the most probable m tokens
2272
- float s_hat = 0.0;
2273
- float sum_ti_bi = 0.0;
2274
- float sum_ti_sq = 0.0;
2275
- for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
2276
- float t_i = logf(float(i + 2) / float(i + 1));
2277
- float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
2278
- sum_ti_bi += t_i * b_i;
2279
- sum_ti_sq += t_i * t_i;
2280
- }
2281
- s_hat = sum_ti_bi / sum_ti_sq;
2282
-
2283
- // Compute k from the estimated s_hat and target surprise value
2284
- float epsilon_hat = s_hat - 1;
2285
- float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
2286
-
2287
- llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
2288
-
2289
- llama_sampler_softmax_impl(cur_p, true);
2290
-
2291
- const int idx = llama_sample_dist(cur_p, ctx->rng);
2292
-
2293
- cur_p->selected = idx;
2294
-
2295
- float observed_surprise = -log2f(cur_p->data[idx].p);
2296
- float e = observed_surprise - ctx->tau;
2297
-
2298
- // Update mu using the learning rate and error
2299
- ctx->mu = ctx->mu - ctx->eta * e;
2300
- }
2301
-
2302
- static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
2303
- const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
2304
- auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
2305
-
2306
- // copy the state
2307
- {
2308
- auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
2309
-
2310
- result_ctx->mu = ctx->mu;
2311
- result_ctx->rng = ctx->rng;
2312
- }
2313
-
2314
- return result;
2315
- }
2316
-
2317
- static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
2318
- auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
2319
- ctx->mu = 2.0f*ctx->tau;
2320
- ctx->seed_cur = get_rng_seed(ctx->seed);
2321
- ctx->rng.seed(ctx->seed_cur);
2322
- }
2323
-
2324
- static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
2325
- delete (llama_sampler_mirostat *) smpl->ctx;
2326
- }
2327
-
2328
- static struct llama_sampler_i llama_sampler_mirostat_i = {
2329
- /* .name = */ llama_sampler_mirostat_name,
2330
- /* .accept = */ nullptr,
2331
- /* .apply = */ llama_sampler_mirostat_apply,
2332
- /* .reset = */ llama_sampler_mirostat_reset,
2333
- /* .clone = */ llama_sampler_mirostat_clone,
2334
- /* .free = */ llama_sampler_mirostat_free,
2335
- /* .backend_init = */ nullptr,
2336
- /* .backend_accept = */ nullptr,
2337
- /* .backend_apply = */ nullptr,
2338
- /* .backend_set_input = */ nullptr,
2339
- };
2340
-
2341
- struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
2342
- const auto seed_cur = get_rng_seed(seed);
2343
-
2344
- return llama_sampler_init(
2345
- /* .iface = */ &llama_sampler_mirostat_i,
2346
- /* .ctx = */ new llama_sampler_mirostat {
2347
- /* .n_vocab = */ n_vocab,
2348
- /* .seed = */ seed,
2349
- /* .seed_cur = */ seed_cur,
2350
- /* .tau = */ tau,
2351
- /* .eta = */ eta,
2352
- /* .m = */ m,
2353
- /* .mu = */ 2.0f*tau,
2354
- /* .rng = */ std::mt19937(seed_cur),
2355
- }
2356
- );
2357
- }
2358
-
2359
- // mirostat v2
2360
-
2361
// State of the Mirostat v2 sampler.
struct llama_sampler_mirostat_v2 {
    const uint32_t seed;     // seed as supplied by the caller
    uint32_t       seed_cur; // seed actually in use (resolved through get_rng_seed)

    const float tau;         // target surprise
    const float eta;         // learning rate for the mu update

    float mu;                // current surprise bound; starts at 2 * tau

    std::mt19937 rng;
};

static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * smpl) {
    (void) smpl;
    return "mirostat-v2";
}
2376
-
2377
- static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2378
- auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
2379
-
2380
- llama_sampler_softmax_impl(cur_p, true);
2381
-
2382
- // Truncate the words with surprise values greater than mu
2383
- cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
2384
- return -log2f(candidate.p) > ctx->mu;
2385
- }));
2386
-
2387
- if (cur_p->size == 0) {
2388
- cur_p->size = 1;
2389
- }
2390
-
2391
- // Normalize the probabilities of the remaining words
2392
- llama_sampler_softmax_impl(cur_p, true);
2393
-
2394
- const int idx = llama_sample_dist(cur_p, ctx->rng);
2395
-
2396
- cur_p->selected = idx;
2397
-
2398
- float observed_surprise = -log2f(cur_p->data[idx].p);
2399
- float e = observed_surprise - ctx->tau;
2400
-
2401
- // Update mu using the learning rate and error
2402
- ctx->mu = ctx->mu - ctx->eta * e;
2403
- }
2404
-
2405
- static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
2406
- auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
2407
- ctx->mu = 2.0f*ctx->tau;
2408
- ctx->seed_cur = get_rng_seed(ctx->seed);
2409
- ctx->rng.seed(ctx->seed_cur);
2410
- }
2411
-
2412
- static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
2413
- const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
2414
-
2415
- auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
2416
-
2417
- // copy the state
2418
- {
2419
- auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
2420
-
2421
- result_ctx->mu = ctx->mu;
2422
- result_ctx->rng = ctx->rng;
2423
- }
2424
-
2425
- return result;
2426
- }
2427
-
2428
- static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
2429
- delete (llama_sampler_mirostat_v2 *) smpl->ctx;
2430
- }
2431
-
2432
- static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
2433
- /* .name = */ llama_sampler_mirostat_v2_name,
2434
- /* .accept = */ nullptr,
2435
- /* .apply = */ llama_sampler_mirostat_v2_apply,
2436
- /* .reset = */ llama_sampler_mirostat_v2_reset,
2437
- /* .clone = */ llama_sampler_mirostat_v2_clone,
2438
- /* .free = */ llama_sampler_mirostat_v2_free,
2439
- /* .backend_init = */ nullptr,
2440
- /* .backend_accept = */ nullptr,
2441
- /* .backend_apply = */ nullptr,
2442
- /* .backend_set_input = */ nullptr,
2443
- };
2444
-
2445
- struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
2446
- auto seed_cur = get_rng_seed(seed);
2447
- return llama_sampler_init(
2448
- /* .iface = */ &llama_sampler_mirostat_v2_i,
2449
- /* .ctx = */ new llama_sampler_mirostat_v2 {
2450
- /* .seed = */ seed,
2451
- /* .seed_cur = */ seed_cur,
2452
- /* .tau = */ tau,
2453
- /* .eta = */ eta,
2454
- /* .mu = */ 2.0f*tau,
2455
- /* .rng = */ std::mt19937(seed_cur),
2456
- }
2457
- );
2458
- }
2459
-
2460
- // grammar
2461
-
2462
// State of the grammar-constrained sampler. `grammar` is nullptr when the
// sampler was created without grammar text; accept/apply/reset then do nothing.
struct llama_sampler_grammar {
    const struct llama_vocab * vocab;

    std::string grammar_str;        // grammar source, kept so reset can rebuild the grammar
    std::string grammar_root;       // root symbol passed along with grammar_str

    struct llama_grammar * grammar; // owned; released by llama_sampler_grammar_free
};

static const char * llama_sampler_grammar_name(const struct llama_sampler * smpl) {
    (void) smpl;
    return "grammar";
}
2474
-
2475
- static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
2476
- auto * ctx = (llama_sampler_grammar *) smpl->ctx;
2477
- if (ctx->grammar) {
2478
- llama_grammar_accept_impl(*ctx->grammar, token);
2479
- }
2480
- }
2481
-
2482
- static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2483
- auto * ctx = (llama_sampler_grammar *) smpl->ctx;
2484
- if (ctx->grammar) {
2485
- llama_grammar_apply_impl(*ctx->grammar, cur_p);
2486
- }
2487
- }
2488
-
2489
- // Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
2490
- static struct llama_sampler * llama_sampler_init_grammar_impl(
2491
- const struct llama_vocab * vocab,
2492
- const char * grammar_str,
2493
- const char * grammar_root,
2494
- bool lazy,
2495
- const char ** trigger_words,
2496
- size_t num_trigger_words,
2497
- const llama_token * trigger_tokens,
2498
- size_t num_trigger_tokens,
2499
- const char ** trigger_patterns,
2500
- size_t num_trigger_patterns);
2501
-
2502
- static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
2503
- auto * ctx = (llama_sampler_grammar *) smpl->ctx;
2504
- if (!ctx->grammar) {
2505
- return;
2506
- }
2507
-
2508
- std::vector<const char *> trigger_patterns_c;
2509
- trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
2510
- for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
2511
- trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
2512
- }
2513
-
2514
- auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
2515
- ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
2516
- ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
2517
-
2518
- llama_grammar_free_impl(ctx->grammar);
2519
- ctx->grammar = grammar_new;
2520
- }
2521
-
2522
- static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
2523
- const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
2524
-
2525
- auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
2526
- GGML_ASSERT(result);
2527
-
2528
- // copy the state
2529
- {
2530
- auto * result_ctx = (llama_sampler_grammar *) result->ctx;
2531
-
2532
- if (ctx->grammar) {
2533
- result_ctx->grammar_str = ctx->grammar_str;
2534
- result_ctx->grammar_root = ctx->grammar_root;
2535
-
2536
- result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
2537
- }
2538
- }
2539
-
2540
- return result;
2541
- }
2542
-
2543
- static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
2544
- const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
2545
-
2546
- if (ctx->grammar) {
2547
- llama_grammar_free_impl(ctx->grammar);
2548
- }
2549
-
2550
- delete ctx;
2551
- }
2552
-
2553
- static struct llama_sampler_i llama_sampler_grammar_i = {
2554
- /* .name = */ llama_sampler_grammar_name,
2555
- /* .accept = */ llama_sampler_grammar_accept_impl,
2556
- /* .apply = */ llama_sampler_grammar_apply,
2557
- /* .reset = */ llama_sampler_grammar_reset,
2558
- /* .clone = */ llama_sampler_grammar_clone,
2559
- /* .free = */ llama_sampler_grammar_free,
2560
- /* .backend_init = */ nullptr,
2561
- /* .backend_accept = */ nullptr,
2562
- /* .backend_apply = */ nullptr,
2563
- /* .backend_set_input = */ nullptr,
2564
- };
2565
-
2566
- static struct llama_sampler * llama_sampler_init_grammar_impl(
2567
- const struct llama_vocab * vocab,
2568
- const char * grammar_str,
2569
- const char * grammar_root,
2570
- bool lazy,
2571
- const char ** trigger_words,
2572
- size_t num_trigger_words,
2573
- const llama_token * trigger_tokens,
2574
- size_t num_trigger_tokens,
2575
- const char ** trigger_patterns,
2576
- size_t num_trigger_patterns) {
2577
- auto * ctx = new llama_sampler_grammar;
2578
-
2579
- if (grammar_str != nullptr && grammar_str[0] != '\0') {
2580
- std::string trigger_pattern;
2581
- llama_grammar * grammar = nullptr;
2582
- // TODO: remove trigger_words support.
2583
- if (trigger_words != nullptr && num_trigger_words > 0) {
2584
- GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
2585
- trigger_pattern = "[\\s\\S]*?(";
2586
- for (size_t i = 0; i < num_trigger_words; ++i) {
2587
- static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
2588
- if (i > 0) {
2589
- trigger_pattern += "|";
2590
- }
2591
- trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
2592
- }
2593
- trigger_pattern += ")[\\s\\S]*";
2594
-
2595
- std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
2596
- grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
2597
- } else {
2598
- grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
2599
- }
2600
- *ctx = {
2601
- /* .vocab = */ vocab,
2602
- /* .grammar_str = */ grammar_str,
2603
- /* .grammar_root = */ grammar_root,
2604
- /* .grammar = */ grammar,
2605
- };
2606
- if (!ctx->grammar) {
2607
- delete ctx;
2608
- return nullptr;
2609
- }
2610
- } else {
2611
- *ctx = {
2612
- /* .vocab = */ vocab,
2613
- /* .grammar_str = */ {},
2614
- /* .grammar_root = */ {},
2615
- /* .grammar = */ nullptr,
2616
- };
2617
- }
2618
-
2619
- return llama_sampler_init(
2620
- /* .iface = */ &llama_sampler_grammar_i,
2621
- /* .ctx = */ ctx
2622
- );
2623
- }
2624
-
2625
- struct llama_sampler * llama_sampler_init_grammar(
2626
- const struct llama_vocab * vocab,
2627
- const char * grammar_str,
2628
- const char * grammar_root) {
2629
- return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0);
2630
- }
2631
-
2632
- struct llama_sampler * llama_sampler_init_grammar_lazy(
2633
- const struct llama_vocab * vocab,
2634
- const char * grammar_str,
2635
- const char * grammar_root,
2636
- const char ** trigger_words,
2637
- size_t num_trigger_words,
2638
- const llama_token * trigger_tokens,
2639
- size_t num_trigger_tokens) {
2640
- return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0);
2641
- }
2642
-
2643
- struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
2644
- const struct llama_vocab * vocab,
2645
- const char * grammar_str,
2646
- const char * grammar_root,
2647
- const char ** trigger_patterns,
2648
- size_t num_trigger_patterns,
2649
- const llama_token * trigger_tokens,
2650
- size_t num_trigger_tokens) {
2651
- return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
2652
- }
2653
-
2654
- // penalties
2655
-
2656
- struct llama_sampler_penalties {
2657
- const int32_t penalty_last_n;
2658
- const float penalty_repeat;
2659
- const float penalty_freq;
2660
- const float penalty_present;
2661
-
2662
- ring_buffer<llama_token> prev;
2663
-
2664
- // a frequency map to count token occurrences
2665
- std::unordered_map<llama_token, int> token_count;
2666
- };
2667
-
2668
- static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
2669
- return "penalties";
2670
- }
2671
-
2672
- static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
2673
- auto * ctx = (llama_sampler_penalties *) smpl->ctx;
2674
- if (ctx->penalty_last_n == 0) {
2675
- return;
2676
- }
2677
-
2678
- ctx->token_count[token]++;
2679
-
2680
- // if the ring buffer is full, remove the oldest token
2681
- if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
2682
- const auto old = ctx->prev.front();
2683
-
2684
- ctx->token_count[old]--;
2685
- if (ctx->token_count[old] == 0) {
2686
- ctx->token_count.erase(old);
2687
- }
2688
- }
2689
-
2690
- ctx->prev.push_back(token);
2691
-
2692
- #if 0
2693
- // sanity check
2694
- std::unordered_map<llama_token, int> tmp;
2695
- for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
2696
- tmp[ctx->prev.rat(i)]++;
2697
- }
2698
-
2699
- assert(ctx->token_count == tmp);
2700
- #endif
2701
- }
2702
-
2703
- static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2704
- auto * ctx = (llama_sampler_penalties *) smpl->ctx;
2705
-
2706
- if ((ctx->penalty_last_n == 0) ||
2707
- (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
2708
- return;
2709
- }
2710
-
2711
- // Apply frequency and presence penalties to the cur_p
2712
- for (size_t i = 0; i < cur_p->size; ++i) {
2713
- const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
2714
- if (token_iter == ctx->token_count.end()) {
2715
- continue;
2716
- }
2717
-
2718
- const int count = token_iter->second;
2719
-
2720
- assert(count > 0 && count <= ctx->penalty_last_n);
2721
-
2722
- // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
2723
- // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
2724
- if (cur_p->data[i].logit <= 0) {
2725
- cur_p->data[i].logit *= ctx->penalty_repeat;
2726
- } else {
2727
- cur_p->data[i].logit /= ctx->penalty_repeat;
2728
- }
2729
-
2730
- cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
2731
- }
2732
-
2733
- cur_p->sorted = false;
2734
- }
2735
-
2736
- static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
2737
- auto * ctx = (llama_sampler_penalties *) smpl->ctx;
2738
- ctx->prev.clear();
2739
- ctx->token_count.clear();
2740
- }
2741
-
2742
- static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
2743
- const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
2744
- auto * result = llama_sampler_init_penalties(
2745
- ctx->penalty_last_n,
2746
- ctx->penalty_repeat,
2747
- ctx->penalty_freq,
2748
- ctx->penalty_present);
2749
-
2750
- // copy the state
2751
- {
2752
- auto * result_ctx = (llama_sampler_penalties *) result->ctx;
2753
-
2754
- result_ctx->prev = ctx->prev;
2755
- }
2756
-
2757
- return result;
2758
- }
2759
-
2760
- static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
2761
- delete (llama_sampler_penalties *) smpl->ctx;
2762
- }
2763
-
2764
- static struct llama_sampler_i llama_sampler_penalties_i = {
2765
- /* .name = */ llama_sampler_penalties_name,
2766
- /* .accept = */ llama_sampler_penalties_accept,
2767
- /* .apply = */ llama_sampler_penalties_apply,
2768
- /* .reset = */ llama_sampler_penalties_reset,
2769
- /* .clone = */ llama_sampler_penalties_clone,
2770
- /* .free = */ llama_sampler_penalties_free,
2771
- /* .backend_init = */ nullptr,
2772
- /* .backend_accept = */ nullptr,
2773
- /* .backend_apply = */ nullptr,
2774
- /* .backend_set_input = */ nullptr,
2775
- };
2776
-
2777
- struct llama_sampler * llama_sampler_init_penalties(
2778
- int32_t penalty_last_n,
2779
- float penalty_repeat,
2780
- float penalty_freq,
2781
- float penalty_present) {
2782
- penalty_last_n = std::max(penalty_last_n, 0);
2783
-
2784
- const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
2785
-
2786
- if (is_empty) {
2787
- return llama_sampler_init_empty("?penalties");
2788
- }
2789
-
2790
- return llama_sampler_init(
2791
- /* .iface = */ &llama_sampler_penalties_i,
2792
- /* .ctx = */ new llama_sampler_penalties {
2793
- /* .penalty_last_n = */ penalty_last_n,
2794
- /* .penalty_repeat = */ penalty_repeat,
2795
- /* .penalty_freq = */ penalty_freq,
2796
- /* .penalty_present = */ penalty_present,
2797
- /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
2798
- /* .token_count = */ {},
2799
- }
2800
- );
2801
- }
2802
-
2803
- // top-n-sigma
2804
-
2805
- struct llama_sampler_top_n_sigma {
2806
- const float n;
2807
- };
2808
-
2809
- static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
2810
- return "top-n-sigma";
2811
- }
2812
-
2813
- static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2814
- auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
2815
-
2816
- if (ctx->n <= 0.0f || cur_p->size <= 1) {
2817
- return;
2818
- }
2819
-
2820
- // find max logit and calculate mean
2821
- float max = cur_p->data[0].logit;
2822
- float logits_sum = 0;
2823
- size_t valid_count = 0;
2824
- for (size_t i = 0; i < cur_p->size; ++i) {
2825
- // Only count non-negative infinity values
2826
- if (cur_p->data[i].logit != -INFINITY) {
2827
- max = std::max(max, cur_p->data[i].logit);
2828
- logits_sum += cur_p->data[i].logit;
2829
- valid_count++;
2830
- }
2831
- }
2832
- float mean = valid_count > 0 ? logits_sum/valid_count : 0;
2833
-
2834
- // calculate standard deviation
2835
- float acc = 0;
2836
- for (size_t i = 0; i < cur_p->size; ++i) {
2837
- // Skip -infinity in std calculation
2838
- if (cur_p->data[i].logit != -INFINITY) {
2839
- acc += pow(cur_p->data[i].logit - mean, 2);
2840
- }
2841
- }
2842
- float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
2843
-
2844
- // apply mask
2845
- for (size_t i = 0; i < cur_p->size; ++i) {
2846
- if (cur_p->data[i].logit < max - (ctx->n * std)) {
2847
- cur_p->data[i].logit = -INFINITY;
2848
- }
2849
- }
2850
-
2851
- llama_sampler_softmax_impl(cur_p, true);
2852
- }
2853
-
2854
- static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
2855
- const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
2856
- return llama_sampler_init_top_n_sigma(ctx->n);
2857
- }
2858
-
2859
- static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
2860
- delete (llama_sampler_top_n_sigma *) smpl->ctx;
2861
- }
2862
-
2863
- static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
2864
- /* .name = */ llama_sampler_top_n_sigma_name,
2865
- /* .accept = */ nullptr,
2866
- /* .apply = */ llama_sampler_top_n_sigma_apply,
2867
- /* .reset = */ nullptr,
2868
- /* .clone = */ llama_sampler_top_n_sigma_clone,
2869
- /* .free = */ llama_sampler_top_n_sigma_free,
2870
- /* .backend_init = */ nullptr,
2871
- /* .backend_accept = */ nullptr,
2872
- /* .backend_apply = */ nullptr,
2873
- /* .backend_set_input = */ nullptr,
2874
- };
2875
-
2876
- struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
2877
- const bool is_empty = (n <= 0.0f);
2878
-
2879
- if (is_empty) {
2880
- return llama_sampler_init_empty("?top-n-sigma");
2881
- }
2882
-
2883
- return llama_sampler_init(
2884
- /* .iface = */ &llama_sampler_top_n_sigma_i,
2885
- /* .ctx = */ new llama_sampler_top_n_sigma {
2886
- /* .n = */ n,
2887
- }
2888
- );
2889
- }
2890
-
2891
- // DRY
2892
-
2893
- struct llama_sampler_dry {
2894
- int32_t total_context_size;
2895
-
2896
- const float dry_multiplier;
2897
- const float dry_base;
2898
- const int32_t dry_allowed_length;
2899
- const int32_t dry_penalty_last_n;
2900
-
2901
- std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
2902
- std::vector<int> dry_repeat_count;
2903
- std::unordered_map<llama_token, int> dry_max_token_repeat;
2904
- ring_buffer<llama_token> last_tokens;
2905
- };
2906
-
2907
- // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
2908
- static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
2909
- for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
2910
- std::string word = vocab.detokenize({token_id}, true);
2911
- if (word.find(str) != std::string::npos) {
2912
- token_sequences.emplace(token_id, std::vector<llama_token>());
2913
- } else {
2914
- size_t word_len = word.size();
2915
- size_t str_len = str.size();
2916
- size_t pos = -1;
2917
- while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
2918
- bool match = true;
2919
- size_t i;
2920
- for (i = 1; i < str_len && i + pos < word_len; ++i) {
2921
- if (word[pos + i] != str[i]) {
2922
- match = false;
2923
- break;
2924
- }
2925
- }
2926
- if (match) {
2927
- std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
2928
- if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
2929
- tokenization.resize(max_tail_len);
2930
- }
2931
-
2932
- // Ensure we don't already have a duplicate matching tokenization
2933
- auto its = token_sequences.equal_range(token_id);
2934
- bool found = false;
2935
- for (auto it = its.first; it != its.second; ++it) {
2936
- if (tokenization == it->second) {
2937
- found = true;
2938
- break;
2939
- }
2940
- }
2941
- if (!found) {
2942
- token_sequences.emplace(token_id, tokenization);
2943
- }
2944
- }
2945
- }
2946
- }
2947
- }
2948
- }
2949
-
2950
- static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
2951
- return "dry";
2952
- }
2953
-
2954
- static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
2955
- auto * ctx = (llama_sampler_dry *) smpl->ctx;
2956
- if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
2957
- return;
2958
- }
2959
-
2960
- ctx->last_tokens.push_back(token);
2961
- }
2962
-
2963
- // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
2964
- static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
2965
- auto * ctx = (llama_sampler_dry *) smpl->ctx;
2966
-
2967
- if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
2968
- return;
2969
- }
2970
-
2971
- int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
2972
- int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
2973
-
2974
- if (last_n_repeat <= ctx->dry_allowed_length) {
2975
- return;
2976
- }
2977
-
2978
- ctx->dry_repeat_count.assign(last_n_repeat, 0);
2979
- ctx->dry_max_token_repeat.clear();
2980
-
2981
- // Step 1: Look for restart sequences to limit the maximum repetition length.
2982
- // Work backwards through the context looking for any token that begins a restart sequence.
2983
- //
2984
- // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
2985
- // sequences that together comprise a restart sequence. This allows us to quickly check
2986
- // whether each token is the head of a complete sequence. Most restart sequences are actually
2987
- // a single token, and for these the "tail" is an empty vector.
2988
- //
2989
- // If the token is a "head", test all restart sequences that begin with this token
2990
- // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
2991
- // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
2992
- // longest matching sequence (if any) is used to limit the maximum repetition length.
2993
- //
2994
- // Note that in the case case of a short sequence contained in a longer one, this might fail to
2995
- // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
2996
- // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
2997
- // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
2998
- //
2999
- // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
3000
- // have already clamped the maximum tail sequence length when generating `restart_sequences`.
3001
- // With clamping, this scan is O(N) in the context length.
3002
-
3003
- int rep_limit = last_n_repeat;
3004
- for (int i = 0; i < last_n_repeat; ++i) {
3005
- llama_token token = ctx->last_tokens.rat(i);
3006
- auto its = ctx->dry_processed_breakers.equal_range(token);
3007
- if (its.first == ctx->dry_processed_breakers.end()) {
3008
- continue;
3009
- }
3010
- int longest_match = -1;
3011
- for (auto it = its.first; it != its.second; ++it) {
3012
- // Note that (*it) does not contain the head character, so seq_len will be
3013
- // the restart sequence length minus 1.
3014
- // In the common case of a single-token restart sequence, (*it) will be empty
3015
- // and we will trivially match.
3016
- int seq_len = (int)it->second.size();
3017
- if (seq_len > longest_match && seq_len <= (int)i) {
3018
- bool match = true;
3019
- for (int offset = 0; offset < seq_len; ++offset) {
3020
- // The -1 when indexing `last_tokens` is because we already matched the head.
3021
- if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
3022
- match = false;
3023
- break;
3024
- }
3025
- }
3026
- if (match) {
3027
- longest_match = seq_len;
3028
- }
3029
- }
3030
- }
3031
- if (longest_match >= 0) {
3032
- // We found a restart sequence starting `i` tokens from the end and continuing for
3033
- // `longest_match` tokens.
3034
- rep_limit = i - longest_match;
3035
- break;
3036
- }
3037
- }
3038
- if (rep_limit < ctx->dry_allowed_length) {
3039
- return;
3040
- }
3041
-
3042
- // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
3043
- // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
3044
- // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
3045
- //
3046
- // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
3047
- // https://ivanyu.me/blog/2014/10/15/z-algorithm/
3048
- //
3049
- // The code below is adapted from the public domain implementation by the same author here:
3050
- // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
3051
- //
3052
- // Example:
3053
- // Last N tokens: a b c c b c y a b c
3054
- // Repeat counts: 0 0 3 1 0 2 0 0 0 0
3055
- // ^
3056
- // This `3` means that the last three tokens of the context (a b c) also appear here.
3057
- //
3058
- // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
3059
- // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
3060
- // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
3061
- // ensure that the inner while loops only examine each token in the context once as the outer
3062
- // for loop iterates over the context.
3063
-
3064
- {
3065
- const int last = last_n_repeat - 1;
3066
-
3067
- int rt = 0;
3068
- int lt = 0;
3069
-
3070
- for (int k = 1; k < last_n_repeat; ++k) {
3071
- if (k > rt) {
3072
- // If k is outside the current Z-box, do naive computation.
3073
- int n = 0;
3074
- while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
3075
- ++n;
3076
- }
3077
- ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
3078
- if (n > 0) {
3079
- lt = k;
3080
- rt = k + n - 1;
3081
- }
3082
- } else {
3083
- // If k is inside the current Z-box, consider two cases.
3084
-
3085
- int p = k - lt; // Pair index.
3086
- int right_part_len = rt - k + 1;
3087
-
3088
- if (ctx->dry_repeat_count[last - p] < right_part_len) {
3089
- int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
3090
- ctx->dry_repeat_count[last - k] = n;
3091
- } else {
3092
- int i = rt + 1;
3093
- while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
3094
- i += 1;
3095
- }
3096
-
3097
- int n = std::min(i - k, rep_limit);
3098
- ctx->dry_repeat_count[last - k] = n;
3099
- lt = k;
3100
- rt = i - 1;
3101
- }
3102
- }
3103
- }
3104
- }
3105
-
3106
- // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
3107
- // that would be generated by emitting each new token that would extend a sequence.
3108
- //
3109
- // Following the same example as above:
3110
- // Last N tokens: a b c c b c y a b c
3111
- // Repeat counts: 0 0 3 1 0 2 0 0 0 0
3112
- //
3113
- // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
3114
- // c: 3 -> 4 (from `a b c` to `a b c c`)
3115
- // b: 1 -> 2 (from `c` to `c b`)
3116
- // y: 2 -> 3 (from `b c` to `b c y`)
3117
-
3118
- for (int i = 0; i < last_n_repeat - 1; ++i) {
3119
- int repeat_len = ctx->dry_repeat_count[i];
3120
- if (repeat_len >= ctx->dry_allowed_length) {
3121
- // This token ends a repeat, so the next token would continue one.
3122
- // By convention, the value of `repeat_len` only includes the tokens currently
3123
- // in the context, not the new token that would be added.
3124
- llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
3125
- // Track the maximum sequence ending in this token.
3126
- const auto& it = ctx->dry_max_token_repeat.find(token);
3127
- if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
3128
- ctx->dry_max_token_repeat[token] = repeat_len;
3129
- }
3130
- }
3131
- }
3132
-
3133
- // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
3134
-
3135
- // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
3136
- // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
3137
- const float FLOAT_MAX_LOG = 88.7228391f;
3138
- int max_exponent = 0;
3139
- if (ctx->dry_base > 1.000001f) {
3140
- max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
3141
- }
3142
-
3143
- for (size_t i = 0; i < cur_p->size; ++i) {
3144
- const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
3145
- if (af_kvp != ctx->dry_max_token_repeat.end()) {
3146
- // Check all sequence breakers starting with this token
3147
- auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
3148
- bool is_single_token_breaker = false;
3149
-
3150
- for (auto it = range.first; it != range.second; ++it) {
3151
- if (it->second.empty()) {
3152
- is_single_token_breaker = true;
3153
- break;
3154
- }
3155
- }
3156
-
3157
- // Apply penalty only if it's not a single-token sequence breaker
3158
- if (!is_single_token_breaker) {
3159
- int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
3160
- if (max_exponent > 0 && repeat_exp > max_exponent) {
3161
- repeat_exp = max_exponent;
3162
- }
3163
- float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
3164
- cur_p->data[i].logit -= penalty;
3165
- }
3166
- }
3167
- }
3168
-
3169
- cur_p->sorted = false;
3170
- }
3171
-
3172
- static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
3173
- auto * ctx = (llama_sampler_dry *) smpl->ctx;
3174
- ctx->last_tokens.clear();
3175
- ctx->dry_repeat_count.clear();
3176
- ctx->dry_max_token_repeat.clear();
3177
- }
3178
-
3179
- static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
3180
- const auto * ctx = (llama_sampler_dry *) smpl->ctx;
3181
-
3182
- llama_vocab dummy_vocab;
3183
-
3184
- // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
3185
- auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
3186
-
3187
- // Copy the state, including the processed breakers
3188
- {
3189
- auto * result_ctx = (llama_sampler_dry *) result->ctx;
3190
- result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
3191
- result_ctx->dry_repeat_count = ctx->dry_repeat_count;
3192
- result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat;
3193
- result_ctx->last_tokens = ctx->last_tokens;
3194
- }
3195
-
3196
- return result;
3197
- }
3198
-
3199
- static void llama_sampler_dry_free(struct llama_sampler * smpl) {
3200
- delete (llama_sampler_dry *) smpl->ctx;
3201
- }
3202
-
3203
- static struct llama_sampler_i llama_sampler_dry_i = {
3204
- /* .name = */ llama_sampler_dry_name,
3205
- /* .accept = */ llama_sampler_dry_accept,
3206
- /* .apply = */ llama_sampler_dry_apply,
3207
- /* .reset = */ llama_sampler_dry_reset,
3208
- /* .clone = */ llama_sampler_dry_clone,
3209
- /* .free = */ llama_sampler_dry_free,
3210
- /* .backend_init = */ nullptr,
3211
- /* .backend_accept = */ nullptr,
3212
- /* .backend_apply = */ nullptr,
3213
- /* .backend_set_input = */ nullptr,
3214
- };
3215
-
3216
- struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
3217
- int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0);
3218
- std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
3219
- const int MAX_CHAR_LEN = 40;
3220
- const int MAX_SEQ_LEN = 20;
3221
-
3222
- const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
3223
-
3224
- if (!dry_enabled) {
3225
- return llama_sampler_init_empty("?dry");
3226
- }
3227
-
3228
- if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
3229
- // Process sequence breakers
3230
- for (size_t i = 0; i < num_breakers; ++i) {
3231
- if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
3232
- LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
3233
- continue;
3234
- }
3235
-
3236
- std::string sequence_break(seq_breakers[i]);
3237
- if (sequence_break.empty()) {
3238
- LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
3239
- continue;
3240
- }
3241
-
3242
- if (sequence_break.size() > MAX_CHAR_LEN) {
3243
- LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
3244
- sequence_break.resize(MAX_CHAR_LEN);
3245
- }
3246
-
3247
- get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
3248
- }
3249
- }
3250
-
3251
- return llama_sampler_init(
3252
- /* .iface = */ &llama_sampler_dry_i,
3253
- /* .ctx = */ new llama_sampler_dry {
3254
- /* .total_context_size = */ n_ctx_train,
3255
- /* .dry_multiplier = */ dry_multiplier,
3256
- /* .dry_base = */ dry_base,
3257
- /* .dry_allowed_length = */ dry_allowed_length,
3258
- /* .dry_penalty_last_n = */ dry_penalty_last_n,
3259
- /* .dry_processed_breakers = */ std::move(processed_breakers),
3260
- /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
3261
- /* .dry_max_token_repeat = */ {},
3262
- /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
3263
- }
3264
- );
3265
- }
3266
-
3267
- // wrapper for test-sampling.cpp
3268
- struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
3269
- llama_vocab dummy_vocab;
3270
- auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
3271
- auto * ctx = (llama_sampler_dry *) result->ctx;
3272
-
3273
- // Process the token-based sequence breakers
3274
- ctx->dry_processed_breakers.clear();
3275
- if (seq_breakers.empty()) {
3276
- LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
3277
- } else {
3278
- for (const auto& breaker : seq_breakers) {
3279
- if (breaker.empty()) {
3280
- LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
3281
- continue;
3282
- }
3283
- llama_token head_token = breaker[0];
3284
- std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
3285
- ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
3286
- }
3287
-
3288
- if (ctx->dry_processed_breakers.empty()) {
3289
- LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
3290
- }
3291
- }
3292
-
3293
- return result;
3294
- }
3295
-
3296
- // logit-bias
3297
-
3298
- struct llama_sampler_logit_bias : public llama_sampler_backend {
3299
- const int32_t n_vocab;
3300
-
3301
- const std::vector<llama_logit_bias> logit_bias;
3302
-
3303
- std::vector<llama_logit_bias> to_search;
3304
-
3305
- struct ggml_tensor * inp_logit_bias;
3306
- struct ggml_tensor * inp_logit_idxs;
3307
-
3308
- ggml_context_ptr inp_ctx;
3309
- ggml_backend_buffer_ptr inp_buf;
3310
- };
3311
-
3312
- static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
3313
- auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
3314
- return ctx->get_name();
3315
- }
3316
-
3317
- static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
3318
- auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
3319
-
3320
- if (ctx->logit_bias.empty()) {
3321
- return;
3322
- }
3323
-
3324
- ctx->to_search.clear();
3325
-
3326
- // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
3327
- for (const auto & lb : ctx->logit_bias) {
3328
- if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
3329
- cur_p->data[lb.token].logit += lb.bias;
3330
- } else {
3331
- ctx->to_search.push_back(lb);
3332
- }
3333
- }
3334
-
3335
- if (ctx->to_search.empty()) {
3336
- return;
3337
- }
3338
-
3339
- // search for the remaining candidates that were not found in the previous step
3340
- for (size_t i = 0; i < cur_p->size; ++i) {
3341
- for (const auto & lb : ctx->to_search) {
3342
- if (cur_p->data[i].id == lb.token) {
3343
- cur_p->data[i].logit += lb.bias;
3344
- break;
3345
- }
3346
- }
3347
- }
3348
- }
3349
-
3350
- static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
3351
- const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
3352
- return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
3353
- }
3354
-
3355
- static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
3356
- delete (llama_sampler_logit_bias *) smpl->ctx;
3357
- }
3358
-
3359
- static void llama_sampler_logit_bias_backend_apply(
3360
- struct llama_sampler * smpl,
3361
- struct ggml_context * ctx,
3362
- struct ggml_cgraph * gf,
3363
- struct llama_sampler_data * data) {
3364
- GGML_UNUSED(gf);
3365
- GGML_UNUSED(ctx);
3366
-
3367
- auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3368
- if (sctx->logit_bias.empty()) {
3369
- return;
3370
- }
3371
-
3372
- ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
3373
-
3374
- cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
3375
- cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
3376
- cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
3377
-
3378
- data->logits = ggml_add(ctx, data->logits, cur);
3379
- }
3380
-
3381
- static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
3382
- auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3383
- if (sctx->logit_bias.empty()) {
3384
- return;
3385
- }
3386
-
3387
- GGML_ASSERT(sctx->inp_logit_bias != nullptr);
3388
- GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
3389
-
3390
- const size_t n = sctx->logit_bias.size();
3391
-
3392
- std::vector<float> data_logit_bias(n, 0.0f);
3393
- std::vector<int32_t> data_logit_idxs(n, 0);
3394
- for (size_t i = 0; i < n; ++i) {
3395
- const auto & lb = sctx->logit_bias[i];
3396
- GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
3397
- data_logit_bias[i] = lb.bias;
3398
- data_logit_idxs[i] = lb.token;
3399
- }
3400
-
3401
- ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
3402
- ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
3403
- }
3404
-
3405
- static bool llama_sampler_logit_bias_backend_init(
3406
- struct llama_sampler * smpl,
3407
- ggml_backend_buffer_type_t buft) {
3408
- auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
3409
-
3410
- sctx->init(true);
3411
-
3412
- if (sctx->logit_bias.empty()) {
3413
- return true;
3414
- }
3415
-
3416
- ggml_init_params params = {
3417
- /*.mem_size =*/ 2*ggml_tensor_overhead(),
3418
- /*.mem_buffer =*/ nullptr,
3419
- /*.no_alloc =*/ true,
3420
- };
3421
-
3422
- sctx->inp_ctx.reset(ggml_init(params));
3423
-
3424
- const size_t n = sctx->logit_bias.size();
3425
-
3426
- sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
3427
- ggml_set_name(sctx->inp_logit_bias, "logit_bias");
3428
- ggml_set_input(sctx->inp_logit_bias);
3429
-
3430
- sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
3431
- ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
3432
- ggml_set_input(sctx->inp_logit_idxs);
3433
-
3434
- // Allocate all tensors from our context to the backend
3435
- sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
3436
-
3437
- ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
3438
-
3439
- return true;
3440
- }
3441
-
3442
- static struct llama_sampler_i llama_sampler_logit_bias_i = {
3443
- /* .name = */ llama_sampler_logit_bias_name,
3444
- /* .accept = */ nullptr,
3445
- /* .apply = */ llama_sampler_logit_bias_apply,
3446
- /* .reset = */ nullptr,
3447
- /* .clone = */ llama_sampler_logit_bias_clone,
3448
- /* .free = */ llama_sampler_logit_bias_free,
3449
- /* .backend_init = */ llama_sampler_logit_bias_backend_init,
3450
- /* .backend_accept = */ nullptr,
3451
- /* .backend_apply = */ llama_sampler_logit_bias_backend_apply,
3452
- /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
3453
- };
3454
-
3455
- struct llama_sampler * llama_sampler_init_logit_bias(
3456
- int32_t n_vocab,
3457
- int32_t n_logit_bias,
3458
- const llama_logit_bias * logit_bias) {
3459
- const bool is_empty = n_logit_bias <= 0;
3460
-
3461
- if (is_empty) {
3462
- return llama_sampler_init_empty("?logit-bias");
3463
- }
3464
-
3465
- return llama_sampler_init(
3466
- /* .iface = */ &llama_sampler_logit_bias_i,
3467
- /* .ctx = */ new llama_sampler_logit_bias {
3468
- ("logit-bias"),
3469
- /* .n_vocab = */ n_vocab,
3470
- /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
3471
- /* .to_search = */ {},
3472
- /* .inp_logit_bias = */ nullptr,
3473
- /* .inp_logit_idxs = */ nullptr,
3474
- /* .inp_ctx = */ nullptr,
3475
- /* .inp_buf = */ nullptr,
3476
- }
3477
- );
3478
- }
3479
-
3480
- // infill
3481
-
3482
- //#define GGML_DEBUG_SAMPLER_INFILL
3483
-
3484
// Per-instance state of the infill sampler.
struct llama_sampler_infill {
    // vocabulary used for EOG checks and token text; not deleted by the sampler
    // (presumably owned by the model and outliving the sampler -- confirm at call sites)
    const struct llama_vocab * vocab;

    // scratch buffers receiving token_to_piece text for the two tokens being compared
    std::vector<char> buf0;
    std::vector<char> buf1;
};
3490
-
3491
- static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
3492
- return "infill";
3493
- }
3494
-
3495
- static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
3496
- auto * ctx = (llama_sampler_infill *) smpl->ctx;
3497
-
3498
- llama_sampler_softmax_impl(cur_p, true);
3499
-
3500
- #if defined(GGML_DEBUG_SAMPLER_INFILL)
3501
- #define LOG_DBG_CUR LLAMA_LOG_DEBUG
3502
- #else
3503
- #define LOG_DBG_CUR(...)
3504
- #endif
3505
-
3506
- for (size_t i = 0; i < cur_p->size; ++i) {
3507
- LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
3508
- }
3509
-
3510
- float p_txt_sum = 0.0f;
3511
- float p_eog_sum = 0.0f;
3512
-
3513
- for (size_t i = 0; i < cur_p->size; ++i) {
3514
- if (ctx->vocab->is_eog(cur_p->data[i].id)) {
3515
- p_eog_sum += cur_p->data[i].p;
3516
- } else {
3517
- p_txt_sum += cur_p->data[i].p;
3518
- }
3519
- }
3520
-
3521
- const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
3522
-
3523
- LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
3524
-
3525
- if (3*p_eog_sum*cur_p->size > p_txt_sum) {
3526
- LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
3527
-
3528
- // keep just the EOG tokens
3529
- const auto size_org = cur_p->size;
3530
-
3531
- cur_p->size = 0;
3532
-
3533
- float p_sum = 0.0f;
3534
-
3535
- for (size_t i = 0; i < size_org; ++i) {
3536
- if (ctx->vocab->is_eog(cur_p->data[i].id)) {
3537
- p_sum += cur_p->data[i].p;
3538
-
3539
- cur_p->data[cur_p->size++] = cur_p->data[i];
3540
- }
3541
- }
3542
-
3543
- // normalize probs
3544
- for (size_t i = 0; i < cur_p->size; ++i) {
3545
- cur_p->data[i].p /= p_sum;
3546
- }
3547
-
3548
- return;
3549
- }
3550
-
3551
- size_t n_combined = 0; GGML_UNUSED(n_combined);
3552
-
3553
- // combine tokens with common prefix
3554
- for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
3555
- for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
3556
- if (cur_p->data[i0].logit == -INFINITY) {
3557
- break;
3558
- }
3559
-
3560
- if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
3561
- continue;
3562
- }
3563
-
3564
- int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
3565
- if (len0 < 0) {
3566
- ctx->buf0.resize(len0);
3567
- len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
3568
- assert(len0 > 0);
3569
- }
3570
-
3571
- int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
3572
- if (len1 < 0) {
3573
- ctx->buf1.resize(len1);
3574
- len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
3575
- assert(len1 > 0);
3576
- }
3577
-
3578
- // token i0 is a prefix of token i1
3579
- if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
3580
- int dst = i0;
3581
- int src = i1;
3582
-
3583
- // merge into the token with higher probability
3584
- if (cur_p->data[i1].p > cur_p->data[i0].p) {
3585
- std::swap(dst, src);
3586
- }
3587
-
3588
- cur_p->data[dst].p += cur_p->data[src].p;
3589
- cur_p->data[src].logit = -INFINITY;
3590
- cur_p->data[src].p = 0.0f;
3591
-
3592
- n_combined++;
3593
- }
3594
- }
3595
- }
3596
-
3597
- size_t n_non_eog = 0;
3598
-
3599
- size_t size_org = cur_p->size;
3600
-
3601
- float p_sum = 0.0f;
3602
- float thold = 0.2f;
3603
-
3604
- cur_p->size = 0;
3605
-
3606
- LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
3607
-
3608
- for (size_t i = 0; i < size_org; ++i) {
3609
- const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
3610
-
3611
- if (cur_p->data[i].p < thold && !is_eog) {
3612
- continue;
3613
- }
3614
-
3615
- if (!is_eog) {
3616
- ++n_non_eog;
3617
- }
3618
-
3619
- p_sum += cur_p->data[i].p;
3620
-
3621
- // keep this token
3622
- cur_p->data[cur_p->size++] = cur_p->data[i];
3623
- }
3624
-
3625
- LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
3626
-
3627
- // if no non-EOG tokens are left -> reduce cur_p to single EOT token
3628
- if (n_non_eog == 0) {
3629
- cur_p->size = 1;
3630
- cur_p->data[0].id = ctx->vocab->token_eot();
3631
- if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
3632
- cur_p->data[0].id = ctx->vocab->token_eos();
3633
- }
3634
- cur_p->data[0].logit = 1.0f;
3635
-
3636
- GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
3637
-
3638
- return;
3639
- }
3640
-
3641
- // normalize probs
3642
- for (size_t i = 0; i < cur_p->size; ++i) {
3643
- cur_p->data[i].p /= p_sum;
3644
-
3645
- LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
3646
- }
3647
-
3648
- size_org = cur_p->size;
3649
- p_sum = 0.0f;
3650
- thold = 1.0/(n_non_eog + 1);
3651
-
3652
- cur_p->size = 0;
3653
-
3654
- LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
3655
-
3656
- for (size_t i = 0; i < size_org; ++i) {
3657
- const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
3658
-
3659
- if (cur_p->data[i].p < thold && !is_eog) {
3660
- continue;
3661
- }
3662
-
3663
- p_sum += cur_p->data[i].p;
3664
-
3665
- cur_p->data[cur_p->size++] = cur_p->data[i];
3666
- }
3667
-
3668
- // normalize probs
3669
- for (size_t i = 0; i < cur_p->size; ++i) {
3670
- cur_p->data[i].p /= p_sum;
3671
-
3672
- LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
3673
- }
3674
-
3675
- #undef LOG_DBG_CUR
3676
- }
3677
-
3678
- static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
3679
- const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
3680
- return llama_sampler_init_infill(ctx->vocab);
3681
- }
3682
-
3683
- static void llama_sampler_infill_free(struct llama_sampler * smpl) {
3684
- delete (llama_sampler_infill *) smpl->ctx;
3685
- }
3686
-
3687
- static struct llama_sampler_i llama_sampler_infill_i = {
3688
- /* .name = */ llama_sampler_infill_name,
3689
- /* .accept = */ nullptr,
3690
- /* .apply = */ llama_sampler_infill_apply,
3691
- /* .reset = */ nullptr,
3692
- /* .clone = */ llama_sampler_infill_clone,
3693
- /* .free = */ llama_sampler_infill_free,
3694
- /* .backend_apply = */ nullptr,
3695
- /* .backend_accept = */ nullptr,
3696
- /* .backend_set_input = */ nullptr,
3697
- /* .backend_init = */ nullptr,
3698
- };
3699
-
3700
- struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
3701
- return llama_sampler_init(
3702
- /* .iface = */ &llama_sampler_infill_i,
3703
- /* .ctx = */ new llama_sampler_infill {
3704
- /* .vocab = */ vocab,
3705
- /* .buf0 = */ std::vector<char>(512),
3706
- /* .buf1 = */ std::vector<char>(512),
3707
- }
3708
- );
3709
- }
3710
-
3711
- // utils
3712
-
3713
- uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
3714
- if (smpl->iface == &llama_sampler_dist_i) {
3715
- return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
3716
- }
3717
-
3718
- if (smpl->iface == &llama_sampler_mirostat_i) {
3719
- return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
3720
- }
3721
-
3722
- if (smpl->iface == &llama_sampler_mirostat_v2_i) {
3723
- return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
3724
- }
3725
-
3726
- if (smpl->iface == &llama_sampler_chain_i) {
3727
- const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
3728
- for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
3729
- const uint32_t seed = llama_sampler_get_seed(it->ptr);
3730
- if (seed != LLAMA_DEFAULT_SEED) {
3731
- return seed;
3732
- }
3733
- }
3734
- }
3735
-
3736
- return LLAMA_DEFAULT_SEED;
3737
- }
3738
-
3739
- // perf
3740
-
3741
- struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
3742
- struct llama_perf_sampler_data data = {};
3743
-
3744
- if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
3745
- GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
3746
- }
3747
-
3748
- const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
3749
-
3750
- data.t_sample_ms = 1e-3 * ctx->t_sample_us;
3751
- data.n_sample = std::max(0, ctx->n_sample);
3752
-
3753
- return data;
3754
- }
3755
-
3756
- void llama_perf_sampler_print(const struct llama_sampler * chain) {
3757
- const auto data = llama_perf_sampler(chain);
3758
-
3759
- LLAMA_LOG_INFO("%s: samplers time = %10.2f ms / %5d runs\n", __func__, data.t_sample_ms, data.n_sample);
3760
- }
3761
-
3762
- void llama_perf_sampler_reset(struct llama_sampler * chain) {
3763
- if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
3764
- GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
3765
- }
3766
-
3767
- auto * ctx = (struct llama_sampler_chain *) chain->ctx;
3768
-
3769
- ctx->t_sample_us = 0;
3770
- ctx->n_sample = 0;
3771
- }