local-llm-rn 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (626)
  1. package/cpp/CMakeLists.txt +285 -0
  2. package/cpp/common/CMakeLists.txt +149 -0
  3. package/cpp/common/arg.cpp +3799 -0
  4. package/cpp/common/arg.h +131 -0
  5. package/cpp/common/base64.hpp +392 -0
  6. package/cpp/common/build-info.cpp.in +4 -0
  7. package/cpp/common/chat-parser-xml-toolcall.cpp +879 -0
  8. package/cpp/common/chat-parser-xml-toolcall.h +45 -0
  9. package/cpp/common/chat-parser.cpp +1649 -0
  10. package/cpp/common/chat-parser.h +133 -0
  11. package/cpp/common/chat-peg-parser.cpp +124 -0
  12. package/cpp/common/chat-peg-parser.h +105 -0
  13. package/cpp/common/chat.cpp +3355 -0
  14. package/cpp/common/chat.h +252 -0
  15. package/cpp/common/common.cpp +1824 -0
  16. package/cpp/common/common.h +930 -0
  17. package/cpp/common/console.cpp +1137 -0
  18. package/cpp/common/console.h +41 -0
  19. package/cpp/common/debug.cpp +167 -0
  20. package/cpp/common/debug.h +43 -0
  21. package/cpp/common/download.cpp +792 -0
  22. package/cpp/common/download.h +84 -0
  23. package/cpp/common/http.h +84 -0
  24. package/cpp/common/jinja/README.md +88 -0
  25. package/cpp/common/jinja/caps.cpp +285 -0
  26. package/cpp/common/jinja/caps.h +30 -0
  27. package/cpp/common/jinja/lexer.cpp +341 -0
  28. package/cpp/common/jinja/lexer.h +157 -0
  29. package/cpp/common/jinja/parser.cpp +591 -0
  30. package/cpp/common/jinja/parser.h +21 -0
  31. package/cpp/common/jinja/runtime.cpp +867 -0
  32. package/cpp/common/jinja/runtime.h +638 -0
  33. package/cpp/common/jinja/string.cpp +213 -0
  34. package/cpp/common/jinja/string.h +61 -0
  35. package/cpp/common/jinja/utils.h +149 -0
  36. package/cpp/common/jinja/value.cpp +1393 -0
  37. package/cpp/common/jinja/value.h +756 -0
  38. package/cpp/common/json-partial.cpp +324 -0
  39. package/cpp/common/json-partial.h +39 -0
  40. package/cpp/common/json-schema-to-grammar.cpp +1153 -0
  41. package/cpp/common/json-schema-to-grammar.h +43 -0
  42. package/cpp/common/llguidance.cpp +258 -0
  43. package/cpp/common/log.cpp +446 -0
  44. package/cpp/common/log.h +119 -0
  45. package/cpp/common/ngram-cache.cpp +285 -0
  46. package/cpp/common/ngram-cache.h +101 -0
  47. package/cpp/common/ngram-map.cpp +530 -0
  48. package/cpp/common/ngram-map.h +115 -0
  49. package/cpp/common/ngram-mod.cpp +60 -0
  50. package/cpp/common/ngram-mod.h +38 -0
  51. package/cpp/common/peg-parser.cpp +1712 -0
  52. package/cpp/common/peg-parser.h +459 -0
  53. package/cpp/common/preset.cpp +483 -0
  54. package/cpp/common/preset.h +83 -0
  55. package/cpp/common/regex-partial.cpp +204 -0
  56. package/cpp/common/regex-partial.h +56 -0
  57. package/cpp/common/sampling.cpp +745 -0
  58. package/cpp/common/sampling.h +119 -0
  59. package/cpp/common/speculative.cpp +1074 -0
  60. package/cpp/common/speculative.h +41 -0
  61. package/cpp/common/unicode.cpp +64 -0
  62. package/cpp/common/unicode.h +22 -0
  63. package/cpp/ggml/CMakeLists.txt +494 -0
  64. package/cpp/ggml/cmake/GitVars.cmake +22 -0
  65. package/cpp/ggml/cmake/common.cmake +50 -0
  66. package/cpp/ggml/cmake/ggml-config.cmake.in +191 -0
  67. package/cpp/ggml/include/ggml-alloc.h +85 -0
  68. package/cpp/ggml/include/ggml-backend.h +373 -0
  69. package/cpp/ggml/include/ggml-blas.h +25 -0
  70. package/cpp/ggml/include/ggml-cann.h +123 -0
  71. package/cpp/ggml/include/ggml-cpp.h +39 -0
  72. package/cpp/ggml/include/ggml-cpu.h +151 -0
  73. package/cpp/ggml/include/ggml-cuda.h +47 -0
  74. package/cpp/ggml/include/ggml-hexagon.h +19 -0
  75. package/cpp/ggml/include/ggml-metal.h +61 -0
  76. package/cpp/ggml/include/ggml-opencl.h +26 -0
  77. package/cpp/ggml/include/ggml-opt.h +256 -0
  78. package/cpp/ggml/include/ggml-rpc.h +30 -0
  79. package/cpp/ggml/include/ggml-sycl.h +49 -0
  80. package/cpp/ggml/include/ggml-virtgpu.h +14 -0
  81. package/cpp/ggml/include/ggml-vulkan.h +29 -0
  82. package/cpp/ggml/include/ggml-webgpu.h +19 -0
  83. package/cpp/ggml/include/ggml-zdnn.h +17 -0
  84. package/cpp/ggml/include/ggml-zendnn.h +22 -0
  85. package/cpp/ggml/include/ggml.h +2753 -0
  86. package/cpp/ggml/include/gguf.h +204 -0
  87. package/cpp/ggml/src/CMakeLists.txt +492 -0
  88. package/cpp/ggml/src/ggml-alloc.c +1244 -0
  89. package/cpp/ggml/src/ggml-backend-dl.cpp +48 -0
  90. package/cpp/ggml/src/ggml-backend-dl.h +45 -0
  91. package/cpp/ggml/src/ggml-backend-impl.h +255 -0
  92. package/cpp/ggml/src/ggml-backend-reg.cpp +566 -0
  93. package/cpp/ggml/src/ggml-backend.cpp +2270 -0
  94. package/cpp/ggml/src/ggml-blas/CMakeLists.txt +101 -0
  95. package/cpp/ggml/src/ggml-blas/ggml-blas.cpp +518 -0
  96. package/cpp/ggml/src/ggml-common.h +1878 -0
  97. package/cpp/ggml/src/ggml-cpu/CMakeLists.txt +691 -0
  98. package/cpp/ggml/src/ggml-cpu/amx/amx.cpp +247 -0
  99. package/cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  100. package/cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  101. package/cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2512 -0
  102. package/cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  103. package/cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +98 -0
  104. package/cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4052 -0
  105. package/cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +4935 -0
  106. package/cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2159 -0
  107. package/cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  108. package/cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  109. package/cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  110. package/cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2726 -0
  111. package/cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  112. package/cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  113. package/cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  114. package/cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  115. package/cpp/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  116. package/cpp/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  117. package/cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  118. package/cpp/ggml/src/ggml-cpu/arch-fallback.h +313 -0
  119. package/cpp/ggml/src/ggml-cpu/binary-ops.cpp +154 -0
  120. package/cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  121. package/cpp/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  122. package/cpp/ggml/src/ggml-cpu/common.h +95 -0
  123. package/cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +529 -0
  124. package/cpp/ggml/src/ggml-cpu/ggml-cpu.c +3734 -0
  125. package/cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +701 -0
  126. package/cpp/ggml/src/ggml-cpu/hbm.cpp +55 -0
  127. package/cpp/ggml/src/ggml-cpu/hbm.h +8 -0
  128. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +938 -0
  129. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +90 -0
  130. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +798 -0
  131. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  132. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +4033 -0
  133. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +25 -0
  134. package/cpp/ggml/src/ggml-cpu/ops.cpp +10978 -0
  135. package/cpp/ggml/src/ggml-cpu/ops.h +116 -0
  136. package/cpp/ggml/src/ggml-cpu/quants.c +1193 -0
  137. package/cpp/ggml/src/ggml-cpu/quants.h +97 -0
  138. package/cpp/ggml/src/ggml-cpu/repack.cpp +3316 -0
  139. package/cpp/ggml/src/ggml-cpu/repack.h +173 -0
  140. package/cpp/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  141. package/cpp/ggml/src/ggml-cpu/simd-mappings.h +1279 -0
  142. package/cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  143. package/cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  144. package/cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  145. package/cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  146. package/cpp/ggml/src/ggml-cpu/traits.cpp +36 -0
  147. package/cpp/ggml/src/ggml-cpu/traits.h +38 -0
  148. package/cpp/ggml/src/ggml-cpu/unary-ops.cpp +337 -0
  149. package/cpp/ggml/src/ggml-cpu/unary-ops.h +35 -0
  150. package/cpp/ggml/src/ggml-cpu/vec.cpp +629 -0
  151. package/cpp/ggml/src/ggml-cpu/vec.h +1585 -0
  152. package/cpp/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  153. package/cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3232 -0
  154. package/cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -0
  155. package/cpp/ggml/src/ggml-hexagon/htp/act-ops.c +815 -0
  156. package/cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  157. package/cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +827 -0
  158. package/cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  159. package/cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c +251 -0
  160. package/cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +666 -0
  161. package/cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c +111 -0
  162. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  163. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  164. package/cpp/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  165. package/cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  166. package/cpp/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  167. package/cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  168. package/cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +154 -0
  169. package/cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +65 -0
  170. package/cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  171. package/cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h +470 -0
  172. package/cpp/ggml/src/ggml-hexagon/htp/hvx-base.h +173 -0
  173. package/cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  174. package/cpp/ggml/src/ggml-hexagon/htp/hvx-div.h +116 -0
  175. package/cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  176. package/cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  177. package/cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  178. package/cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h +176 -0
  179. package/cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h +266 -0
  180. package/cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  181. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  182. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  183. package/cpp/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  184. package/cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -0
  185. package/cpp/ggml/src/ggml-hexagon/htp/main.c +1150 -0
  186. package/cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2595 -0
  187. package/cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +498 -0
  188. package/cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c +167 -0
  189. package/cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +421 -0
  190. package/cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +130 -0
  191. package/cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +384 -0
  192. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  193. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  194. package/cpp/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  195. package/cpp/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  196. package/cpp/ggml/src/ggml-hexagon/libdl.h +79 -0
  197. package/cpp/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  198. package/cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
  199. package/cpp/ggml/src/ggml-impl.h +724 -0
  200. package/cpp/ggml/src/ggml-metal/CMakeLists.txt +124 -0
  201. package/cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +457 -0
  202. package/cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  203. package/cpp/ggml/src/ggml-metal/ggml-metal-context.h +41 -0
  204. package/cpp/ggml/src/ggml-metal/ggml-metal-context.m +702 -0
  205. package/cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1890 -0
  206. package/cpp/ggml/src/ggml-metal/ggml-metal-device.h +290 -0
  207. package/cpp/ggml/src/ggml-metal/ggml-metal-device.m +1749 -0
  208. package/cpp/ggml/src/ggml-metal/ggml-metal-impl.h +1054 -0
  209. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +4370 -0
  210. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  211. package/cpp/ggml/src/ggml-metal/ggml-metal.cpp +937 -0
  212. package/cpp/ggml/src/ggml-metal/ggml-metal.metal +9819 -0
  213. package/cpp/ggml/src/ggml-musa/CMakeLists.txt +125 -0
  214. package/cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  215. package/cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  216. package/cpp/ggml/src/ggml-opencl/CMakeLists.txt +150 -0
  217. package/cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +11553 -0
  218. package/cpp/ggml/src/ggml-opencl/kernels/add.cl +190 -0
  219. package/cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  220. package/cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  221. package/cpp/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  222. package/cpp/ggml/src/ggml-opencl/kernels/concat.cl +51 -0
  223. package/cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  224. package/cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  225. package/cpp/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  226. package/cpp/ggml/src/ggml-opencl/kernels/cvt.cl +417 -0
  227. package/cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  228. package/cpp/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  229. package/cpp/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  230. package/cpp/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  231. package/cpp/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  232. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  233. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  234. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  235. package/cpp/ggml/src/ggml-opencl/kernels/gelu.cl +89 -0
  236. package/cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  237. package/cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  238. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  239. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  240. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  241. package/cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +187 -0
  242. package/cpp/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  243. package/cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  244. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  245. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  246. package/cpp/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  247. package/cpp/ggml/src/ggml-opencl/kernels/mul.cl +152 -0
  248. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  249. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  250. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  251. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  252. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  253. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  254. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  255. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  256. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  257. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  258. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  259. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  260. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  261. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  262. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  263. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  264. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  265. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  266. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  267. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  268. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  269. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  270. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  271. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  272. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  273. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  274. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  275. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  276. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  277. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  278. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl +194 -0
  279. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  280. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  281. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  282. package/cpp/ggml/src/ggml-opencl/kernels/norm.cl +161 -0
  283. package/cpp/ggml/src/ggml-opencl/kernels/pad.cl +39 -0
  284. package/cpp/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  285. package/cpp/ggml/src/ggml-opencl/kernels/repeat.cl +38 -0
  286. package/cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +190 -0
  287. package/cpp/ggml/src/ggml-opencl/kernels/rope.cl +747 -0
  288. package/cpp/ggml/src/ggml-opencl/kernels/scale.cl +27 -0
  289. package/cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  290. package/cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  291. package/cpp/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  292. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +108 -0
  293. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +108 -0
  294. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +107 -0
  295. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +107 -0
  296. package/cpp/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  297. package/cpp/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  298. package/cpp/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  299. package/cpp/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  300. package/cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  301. package/cpp/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  302. package/cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +140 -0
  303. package/cpp/ggml/src/ggml-opencl/kernels/tanh.cl +109 -0
  304. package/cpp/ggml/src/ggml-opencl/kernels/transpose.cl +117 -0
  305. package/cpp/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  306. package/cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  307. package/cpp/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  308. package/cpp/ggml/src/ggml-opt.cpp +1093 -0
  309. package/cpp/ggml/src/ggml-quants.c +5325 -0
  310. package/cpp/ggml/src/ggml-quants.h +106 -0
  311. package/cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  312. package/cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2118 -0
  313. package/cpp/ggml/src/ggml-threading.cpp +12 -0
  314. package/cpp/ggml/src/ggml-threading.h +14 -0
  315. package/cpp/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  316. package/cpp/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  317. package/cpp/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  318. package/cpp/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  319. package/cpp/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  320. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  321. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  322. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  323. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  324. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  325. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  326. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  327. package/cpp/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  328. package/cpp/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  329. package/cpp/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  330. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  331. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  332. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  333. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  334. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  335. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  336. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  337. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  338. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  339. package/cpp/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  340. package/cpp/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  341. package/cpp/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  342. package/cpp/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  343. package/cpp/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  344. package/cpp/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  345. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  346. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  347. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  348. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  349. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  350. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  351. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  352. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  353. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  354. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  355. package/cpp/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  356. package/cpp/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  357. package/cpp/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  358. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1231 -0
  359. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3150 -0
  360. package/cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  361. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  362. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  363. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  364. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +107 -0
  365. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +923 -0
  366. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  367. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  368. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +182 -0
  369. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  370. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +668 -0
  371. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  372. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  373. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +713 -0
  374. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +103 -0
  375. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +138 -0
  376. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +188 -0
  377. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +194 -0
  378. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  379. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  380. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  381. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  382. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +109 -0
  383. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  384. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  385. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  386. package/cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  387. package/cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
  388. package/cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +633 -0
  389. package/cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  390. package/cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  391. package/cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
  392. package/cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
  393. package/cpp/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  394. package/cpp/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  395. package/cpp/ggml/src/ggml.c +7669 -0
  396. package/cpp/ggml/src/ggml.cpp +26 -0
  397. package/cpp/ggml/src/gguf.cpp +1699 -0
  398. package/cpp/include/llama-cpp.h +32 -0
  399. package/cpp/include/llama.h +1568 -0
  400. package/cpp/mtmd/CMakeLists.txt +98 -0
  401. package/cpp/mtmd/README.md +63 -0
  402. package/cpp/mtmd/clip-graph.h +117 -0
  403. package/cpp/mtmd/clip-impl.h +586 -0
  404. package/cpp/mtmd/clip-model.h +390 -0
  405. package/cpp/mtmd/clip.cpp +4154 -0
  406. package/cpp/mtmd/clip.h +121 -0
  407. package/cpp/mtmd/deprecation-warning.cpp +22 -0
  408. package/cpp/mtmd/legacy-models/convert_image_encoder_to_gguf.py +412 -0
  409. package/cpp/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py +280 -0
  410. package/cpp/mtmd/legacy-models/glmedge-surgery.py +33 -0
  411. package/cpp/mtmd/legacy-models/llava_surgery.py +38 -0
  412. package/cpp/mtmd/legacy-models/llava_surgery_v2.py +180 -0
  413. package/cpp/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +892 -0
  414. package/cpp/mtmd/legacy-models/minicpmv-surgery.py +47 -0
  415. package/cpp/mtmd/models/cogvlm.cpp +98 -0
  416. package/cpp/mtmd/models/conformer.cpp +216 -0
  417. package/cpp/mtmd/models/glm4v.cpp +122 -0
  418. package/cpp/mtmd/models/internvl.cpp +69 -0
  419. package/cpp/mtmd/models/kimik25.cpp +101 -0
  420. package/cpp/mtmd/models/kimivl.cpp +63 -0
  421. package/cpp/mtmd/models/llama4.cpp +96 -0
  422. package/cpp/mtmd/models/llava.cpp +374 -0
  423. package/cpp/mtmd/models/minicpmv.cpp +114 -0
  424. package/cpp/mtmd/models/mobilenetv5.cpp +451 -0
  425. package/cpp/mtmd/models/models.h +128 -0
  426. package/cpp/mtmd/models/nemotron-v2-vl.cpp +35 -0
  427. package/cpp/mtmd/models/paddleocr.cpp +52 -0
  428. package/cpp/mtmd/models/pixtral.cpp +86 -0
  429. package/cpp/mtmd/models/qwen2vl.cpp +183 -0
  430. package/cpp/mtmd/models/qwen3vl.cpp +193 -0
  431. package/cpp/mtmd/models/siglip.cpp +86 -0
  432. package/cpp/mtmd/models/whisper-enc.cpp +115 -0
  433. package/cpp/mtmd/models/youtuvl.cpp +179 -0
  434. package/cpp/mtmd/mtmd-audio.cpp +730 -0
  435. package/cpp/mtmd/mtmd-audio.h +113 -0
  436. package/cpp/mtmd/mtmd-cli.cpp +437 -0
  437. package/cpp/mtmd/mtmd-helper.cpp +521 -0
  438. package/cpp/mtmd/mtmd-helper.h +96 -0
  439. package/cpp/mtmd/mtmd.cpp +1156 -0
  440. package/cpp/mtmd/mtmd.h +319 -0
  441. package/cpp/mtmd/requirements.txt +5 -0
  442. package/cpp/mtmd/test-1.jpeg +0 -0
  443. package/cpp/mtmd/test-2.mp3 +0 -0
  444. package/cpp/mtmd/tests.sh +192 -0
  445. package/cpp/src/CMakeLists.txt +169 -0
  446. package/cpp/src/llama-adapter.cpp +488 -0
  447. package/cpp/src/llama-adapter.h +89 -0
  448. package/cpp/src/llama-arch.cpp +2855 -0
  449. package/cpp/src/llama-arch.h +619 -0
  450. package/cpp/src/llama-batch.cpp +917 -0
  451. package/cpp/src/llama-batch.h +173 -0
  452. package/cpp/src/llama-chat.cpp +896 -0
  453. package/cpp/src/llama-chat.h +71 -0
  454. package/cpp/src/llama-context.cpp +3512 -0
  455. package/cpp/src/llama-context.h +359 -0
  456. package/cpp/src/llama-cparams.cpp +5 -0
  457. package/cpp/src/llama-cparams.h +44 -0
  458. package/cpp/src/llama-grammar.cpp +1464 -0
  459. package/cpp/src/llama-grammar.h +194 -0
  460. package/cpp/src/llama-graph.cpp +2685 -0
  461. package/cpp/src/llama-graph.h +1026 -0
  462. package/cpp/src/llama-hparams.cpp +234 -0
  463. package/cpp/src/llama-hparams.h +339 -0
  464. package/cpp/src/llama-impl.cpp +171 -0
  465. package/cpp/src/llama-impl.h +73 -0
  466. package/cpp/src/llama-io.cpp +15 -0
  467. package/cpp/src/llama-io.h +35 -0
  468. package/cpp/src/llama-kv-cache-iswa.cpp +330 -0
  469. package/cpp/src/llama-kv-cache-iswa.h +137 -0
  470. package/cpp/src/llama-kv-cache.cpp +2271 -0
  471. package/cpp/src/llama-kv-cache.h +388 -0
  472. package/cpp/src/llama-kv-cells.h +533 -0
  473. package/cpp/src/llama-memory-hybrid-iswa.cpp +275 -0
  474. package/cpp/src/llama-memory-hybrid-iswa.h +140 -0
  475. package/cpp/src/llama-memory-hybrid.cpp +268 -0
  476. package/cpp/src/llama-memory-hybrid.h +139 -0
  477. package/cpp/src/llama-memory-recurrent.cpp +1165 -0
  478. package/cpp/src/llama-memory-recurrent.h +182 -0
  479. package/cpp/src/llama-memory.cpp +59 -0
  480. package/cpp/src/llama-memory.h +122 -0
  481. package/cpp/src/llama-mmap.cpp +785 -0
  482. package/cpp/src/llama-mmap.h +92 -0
  483. package/cpp/src/llama-model-loader.cpp +1414 -0
  484. package/cpp/src/llama-model-loader.h +203 -0
  485. package/cpp/src/llama-model-saver.cpp +286 -0
  486. package/cpp/src/llama-model-saver.h +37 -0
  487. package/cpp/src/llama-model.cpp +9253 -0
  488. package/cpp/src/llama-model.h +576 -0
  489. package/cpp/src/llama-quant.cpp +1119 -0
  490. package/cpp/src/llama-quant.h +1 -0
  491. package/cpp/src/llama-sampler.cpp +3885 -0
  492. package/cpp/src/llama-sampler.h +42 -0
  493. package/cpp/src/llama-vocab.cpp +3970 -0
  494. package/cpp/src/llama-vocab.h +187 -0
  495. package/cpp/src/llama.cpp +1313 -0
  496. package/cpp/src/models/afmoe.cpp +191 -0
  497. package/cpp/src/models/apertus.cpp +125 -0
  498. package/cpp/src/models/arcee.cpp +135 -0
  499. package/cpp/src/models/arctic.cpp +138 -0
  500. package/cpp/src/models/arwkv7.cpp +86 -0
  501. package/cpp/src/models/baichuan.cpp +122 -0
  502. package/cpp/src/models/bailingmoe.cpp +144 -0
  503. package/cpp/src/models/bailingmoe2.cpp +135 -0
  504. package/cpp/src/models/bert.cpp +178 -0
  505. package/cpp/src/models/bitnet.cpp +160 -0
  506. package/cpp/src/models/bloom.cpp +101 -0
  507. package/cpp/src/models/chameleon.cpp +178 -0
  508. package/cpp/src/models/chatglm.cpp +132 -0
  509. package/cpp/src/models/codeshell.cpp +111 -0
  510. package/cpp/src/models/cogvlm.cpp +102 -0
  511. package/cpp/src/models/cohere2-iswa.cpp +134 -0
  512. package/cpp/src/models/command-r.cpp +122 -0
  513. package/cpp/src/models/dbrx.cpp +123 -0
  514. package/cpp/src/models/deci.cpp +135 -0
  515. package/cpp/src/models/deepseek.cpp +144 -0
  516. package/cpp/src/models/deepseek2.cpp +262 -0
  517. package/cpp/src/models/delta-net-base.cpp +376 -0
  518. package/cpp/src/models/dots1.cpp +134 -0
  519. package/cpp/src/models/dream.cpp +105 -0
  520. package/cpp/src/models/ernie4-5-moe.cpp +150 -0
  521. package/cpp/src/models/ernie4-5.cpp +110 -0
  522. package/cpp/src/models/eurobert.cpp +97 -0
  523. package/cpp/src/models/exaone-moe.cpp +146 -0
  524. package/cpp/src/models/exaone.cpp +114 -0
  525. package/cpp/src/models/exaone4.cpp +123 -0
  526. package/cpp/src/models/falcon-h1.cpp +111 -0
  527. package/cpp/src/models/falcon.cpp +120 -0
  528. package/cpp/src/models/gemma-embedding.cpp +116 -0
  529. package/cpp/src/models/gemma.cpp +112 -0
  530. package/cpp/src/models/gemma2-iswa.cpp +128 -0
  531. package/cpp/src/models/gemma3.cpp +155 -0
  532. package/cpp/src/models/gemma3n-iswa.cpp +384 -0
  533. package/cpp/src/models/glm4-moe.cpp +170 -0
  534. package/cpp/src/models/glm4.cpp +157 -0
  535. package/cpp/src/models/gpt2.cpp +105 -0
  536. package/cpp/src/models/gptneox.cpp +144 -0
  537. package/cpp/src/models/granite-hybrid.cpp +196 -0
  538. package/cpp/src/models/granite.cpp +211 -0
  539. package/cpp/src/models/grok.cpp +159 -0
  540. package/cpp/src/models/grovemoe.cpp +141 -0
  541. package/cpp/src/models/hunyuan-dense.cpp +132 -0
  542. package/cpp/src/models/hunyuan-moe.cpp +154 -0
  543. package/cpp/src/models/internlm2.cpp +120 -0
  544. package/cpp/src/models/jais.cpp +86 -0
  545. package/cpp/src/models/jais2.cpp +123 -0
  546. package/cpp/src/models/jamba.cpp +106 -0
  547. package/cpp/src/models/kimi-linear.cpp +392 -0
  548. package/cpp/src/models/lfm2.cpp +190 -0
  549. package/cpp/src/models/llada-moe.cpp +122 -0
  550. package/cpp/src/models/llada.cpp +99 -0
  551. package/cpp/src/models/llama-iswa.cpp +178 -0
  552. package/cpp/src/models/llama.cpp +168 -0
  553. package/cpp/src/models/maincoder.cpp +117 -0
  554. package/cpp/src/models/mamba-base.cpp +285 -0
  555. package/cpp/src/models/mamba.cpp +54 -0
  556. package/cpp/src/models/mimo2-iswa.cpp +123 -0
  557. package/cpp/src/models/minicpm3.cpp +200 -0
  558. package/cpp/src/models/minimax-m2.cpp +124 -0
  559. package/cpp/src/models/mistral3.cpp +160 -0
  560. package/cpp/src/models/models.h +684 -0
  561. package/cpp/src/models/modern-bert.cpp +109 -0
  562. package/cpp/src/models/mpt.cpp +126 -0
  563. package/cpp/src/models/nemotron-h.cpp +148 -0
  564. package/cpp/src/models/nemotron.cpp +122 -0
  565. package/cpp/src/models/neo-bert.cpp +104 -0
  566. package/cpp/src/models/olmo.cpp +121 -0
  567. package/cpp/src/models/olmo2.cpp +150 -0
  568. package/cpp/src/models/olmoe.cpp +124 -0
  569. package/cpp/src/models/openai-moe-iswa.cpp +127 -0
  570. package/cpp/src/models/openelm.cpp +124 -0
  571. package/cpp/src/models/orion.cpp +123 -0
  572. package/cpp/src/models/paddleocr.cpp +122 -0
  573. package/cpp/src/models/pangu-embedded.cpp +121 -0
  574. package/cpp/src/models/phi2.cpp +121 -0
  575. package/cpp/src/models/phi3.cpp +152 -0
  576. package/cpp/src/models/plamo.cpp +110 -0
  577. package/cpp/src/models/plamo2.cpp +318 -0
  578. package/cpp/src/models/plamo3.cpp +128 -0
  579. package/cpp/src/models/plm.cpp +169 -0
  580. package/cpp/src/models/qwen.cpp +108 -0
  581. package/cpp/src/models/qwen2.cpp +126 -0
  582. package/cpp/src/models/qwen2moe.cpp +151 -0
  583. package/cpp/src/models/qwen2vl.cpp +117 -0
  584. package/cpp/src/models/qwen3.cpp +117 -0
  585. package/cpp/src/models/qwen35.cpp +386 -0
  586. package/cpp/src/models/qwen35moe.cpp +420 -0
  587. package/cpp/src/models/qwen3moe.cpp +124 -0
  588. package/cpp/src/models/qwen3next.cpp +525 -0
  589. package/cpp/src/models/qwen3vl-moe.cpp +140 -0
  590. package/cpp/src/models/qwen3vl.cpp +132 -0
  591. package/cpp/src/models/refact.cpp +94 -0
  592. package/cpp/src/models/rnd1.cpp +126 -0
  593. package/cpp/src/models/rwkv6-base.cpp +164 -0
  594. package/cpp/src/models/rwkv6.cpp +94 -0
  595. package/cpp/src/models/rwkv6qwen2.cpp +86 -0
  596. package/cpp/src/models/rwkv7-base.cpp +137 -0
  597. package/cpp/src/models/rwkv7.cpp +90 -0
  598. package/cpp/src/models/seed-oss.cpp +124 -0
  599. package/cpp/src/models/smallthinker.cpp +126 -0
  600. package/cpp/src/models/smollm3.cpp +128 -0
  601. package/cpp/src/models/stablelm.cpp +146 -0
  602. package/cpp/src/models/starcoder.cpp +100 -0
  603. package/cpp/src/models/starcoder2.cpp +121 -0
  604. package/cpp/src/models/step35-iswa.cpp +168 -0
  605. package/cpp/src/models/t5-dec.cpp +166 -0
  606. package/cpp/src/models/t5-enc.cpp +96 -0
  607. package/cpp/src/models/wavtokenizer-dec.cpp +149 -0
  608. package/cpp/src/models/xverse.cpp +108 -0
  609. package/cpp/src/unicode-data.cpp +7034 -0
  610. package/cpp/src/unicode-data.h +20 -0
  611. package/cpp/src/unicode.cpp +1103 -0
  612. package/cpp/src/unicode.h +111 -0
  613. package/cpp/vendor/nlohmann/json.hpp +25526 -0
  614. package/cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  615. package/cpp/vendor/stb/stb_image.h +7988 -0
  616. package/ios/LocalLLM-Bridging-Header.h +2 -0
  617. package/ios/LocalLLM.h +5 -0
  618. package/ios/LocalLLM.mm +1267 -0
  619. package/local-llm-rn.podspec +60 -0
  620. package/package.json +35 -0
  621. package/src/NativeLocalLLM.ts +73 -0
  622. package/src/device.ts +50 -0
  623. package/src/download-adapter.ts +17 -0
  624. package/src/index.ts +21 -0
  625. package/src/native-bridge.ts +142 -0
  626. package/src/rn-downloader.ts +37 -0
@@ -0,0 +1,636 @@
1
+ diagnostic(off, chromium.subgroup_matrix_uniformity);
2
+ diagnostic(off, subgroup_uniformity);
3
+ enable f16;
4
+ enable subgroups;
5
+ enable chromium_experimental_subgroup_matrix;
6
+
7
+ #ifdef KV_F32
8
+ #define KV_TYPE f32
9
+ #else
10
+ #define KV_TYPE f16
11
+ #endif
12
+
13
+ // Default head dimensions: elements per Q/K row (HEAD_DIM_QK) and per V/output row (HEAD_DIM_V)
14
+ #define HEAD_DIM_QK 64
15
+ #define HEAD_DIM_V 64
16
+
17
+ // The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN
18
+ // Note that the "K" here does not correspond to the K in attention's Q/K/V; it is just the common (inner) dimension.
19
+ #define SG_MAT_M 8
20
+ #define SG_MAT_N 8
21
+ #define SG_MAT_K 8
22
+
23
+ // Each workgroup processes one subgroup matrix of Q rows (Q_TILE rows) against KV_TILE K/V rows at a time, using WG_SIZE threads
24
+ #define Q_TILE SG_MAT_M
25
+ #define KV_TILE 16
26
+ #define WG_SIZE 64
27
+
28
+ // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.
29
+ #define KV_BLOCKS (KV_TILE / SG_MAT_N)
30
+
31
+ // Quantization constants/helpers: BLOCK_SIZE elements per quantized block, BLOCKS_K / BLOCKS_V blocks per K / V row (rounded up)
32
+ #define BLOCK_SIZE 32
33
+ #define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
34
+ #define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
35
+ // NQ: number of quantized elements each thread dequantizes per iteration of the K/V tile load loops
36
+ #if defined(KV_Q4_0)
37
+ #define NQ 16
38
+ // Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights (4 weights per f16)
39
+ #define F16_PER_BLOCK 9
40
+ #define WEIGHTS_PER_F16 4
41
+ #elif defined(KV_Q8_0)
42
+ #define NQ 8
43
+ // Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights (2 weights per f16)
44
+ #define F16_PER_BLOCK 17
45
+ #define WEIGHTS_PER_F16 2
46
+ #endif
47
+ #define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
48
+
49
+ // Ok not to put these in a define block, compiler will remove if unused
50
+ fn get_byte(value: u32, index: u32) -> u32 { // unsigned byte number `index` of `value`, where byte 0 is the least significant
51
+ return (value >> (8u * index)) & 255u;
52
+ }
53
+
54
+ fn get_byte_i32(value: u32, index: u32) -> i32 { // byte `index` of `value` as a sign-extended i32 (range -128..127)
55
+ return bitcast<i32>((value >> (8u * index)) << 24u) >> 24u; // the left shift keeps only the selected byte; the signed right shift sign-extends it
56
+ }
57
+
58
+ struct Params {
59
+ offset_q: u32, // base index added to every access into Q
60
+ offset_k: u32, // base index added to every access into K
61
+ offset_v: u32, // base index added to every access into V
62
+ offset_mask: u32, // base index into mask (read only when MASK is defined)
63
+ offset_sinks: u32, // base index into sinks; the head index is added to it (SINKS only)
64
+ offset_dst: u32, // base scalar index into dst (divided by 4 when dst is written as vec4)
65
+
66
+ // shapes of Q/K/V
67
+ n_heads: u32, // number of Q heads; also sets the dst row stride (HEAD_DIM_V * n_heads)
68
+ seq_len_q: u32, // number of Q rows per head
69
+ seq_len_kv: u32, // number of K/V rows per head; also used as the mask row length
70
+
71
+ // strides (in elements); suffix 1 = per row, 2 = per head, 3 = per batch. In the quantized K/V load paths the K/V strides are combined with a block index before scaling by F16_PER_BLOCK, so there they count quantization blocks
72
+ stride_q1: u32,
73
+ stride_q2: u32,
74
+ stride_q3: u32,
75
+ stride_k1: u32,
76
+ stride_k2: u32,
77
+ stride_k3: u32,
78
+ stride_v1: u32,
79
+ stride_v2: u32,
80
+ stride_v3: u32,
81
+ stride_mask3: u32, // mask stride per batch
82
+
83
+ // repeat factor for K/V, e.g., MHA vs. MQA vs. GQA: the K/V head index is head_idx / q_per_kv
84
+ q_per_kv: u32,
85
+
86
+ // softmax params
87
+ scale: f32, // multiplied into every Q*K^T score
88
+ max_bias: f32, // the per-head mask slope is computed only when this is > 0, otherwise the slope is 1.0
89
+ logit_softcap: f32, // with LOGIT_SOFTCAP defined, scores become logit_softcap * tanh(score)
90
+ n_head_log2: f32, // heads below this index use base m0 for the slope, the others use m1
91
+ m0: f32,
92
+ m1: f32,
93
+ };
94
+
95
+ @group(0) @binding(0) var<storage, read_write> Q: array<f32>;
96
+ @group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; // in the quantized modes this is read as a stream of f16 words (scale followed by packed weights)
97
+ @group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
98
+
99
+ #if defined(MASK) && defined(SINKS)
100
+ @group(0) @binding(3) var<storage, read_write> mask: array<f16>;
101
+ @group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
102
+ #define DST_BINDING 5
103
+ #define PARAMS_BINDING 6
104
+ #elif defined(MASK)
105
+ @group(0) @binding(3) var<storage, read_write> mask: array<f16>;
106
+ #define DST_BINDING 4
107
+ #define PARAMS_BINDING 5
108
+ #elif defined(SINKS)
109
+ @group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
110
+ #define DST_BINDING 4
111
+ #define PARAMS_BINDING 5
112
+ #else
113
+ #define DST_BINDING 3
114
+ #define PARAMS_BINDING 4
115
+ #endif
116
+
117
+ @group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>; // output, written four floats at a time
118
+ @group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
119
+
120
+ // Large negative sentinel used in place of -infinity (initial row max, lanes outside the KV tile, non-sink lanes)
121
+ const FLOAT_MIN: f32 = -1.0e9;
122
+
123
+ // Q tile for this workgroup: Q_TILE rows of HEAD_DIM_QK elements, stored as f16
124
+ var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
125
+
126
+ #ifndef KV_DIRECT
127
+ const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
128
+ // we can reuse the same shmem for K and V since we only need one at a time
129
+ var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
130
+ #endif
131
+
132
+ var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>; // output shmem: unnormalized output accumulator, one row per Q row
133
+
134
+ #ifdef MASK
135
+ // storage for mask values of the current Q_TILE x KV_TILE tile
136
+ var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
137
+ #endif
138
+
139
+ // storage for output of Q*K^T scores for online softmax (S matrix from paper)
140
+ // also storage for the exponentiated scores during online softmax (P matrix from paper)
141
+ // note that we reuse the same storage for both since we only need one at a time
142
+ var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
143
+
144
+ // Storage for the running row max and running exp sum during online softmax (one entry per Q row)
145
+ var<workgroup> row_max_shmem: array<f32, Q_TILE>;
146
+ var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
147
+
148
+ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { // pre-exponent softmax input for score (q_tile_row, kv_idx) of the current tile
149
+ let in_tile = kv_idx < KV_TILE; // lanes at or past KV_TILE yield the FLOAT_MIN sentinel instead of a score
150
+ let cell = q_tile_row * KV_TILE + kv_idx; // flat index shared by inter_shmem and mask_shmem
151
+ var term = select(FLOAT_MIN, params.scale * f32(inter_shmem[cell]), in_tile);
152
+ #ifdef LOGIT_SOFTCAP
153
+ term = tanh(term) * params.logit_softcap;
154
+ #endif
155
+ #ifdef MASK
156
+ // lanes outside the tile add no mask bias
157
+ let bias = select(0.0, f32(mask_shmem[cell]), in_tile);
158
+ term += slope * bias;
159
+ #endif
160
+ return term;
161
+ }
162
+
163
+ fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32) -> vec4<f32> { // returns the vec4 that contains scalar element `scalar_index`
164
+ return (*buf)[scalar_index / 4u]; // four scalars per vec4 element
165
+ }
166
+
167
+ fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> { // returns the vec4 that contains scalar element `scalar_index`
168
+ return (*buf)[scalar_index / 4u]; // four scalars per vec4 element
169
+ }
170
+
171
+ @compute @workgroup_size(WG_SIZE)
172
+ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, // entry point: each workgroup computes attention output for one tile of Q_TILE query rows of one head; only wg_id.x is used
173
+ @builtin(local_invocation_id) local_id: vec3<u32>,
174
+ @builtin(subgroup_id) subgroup_id: u32,
175
+ @builtin(subgroup_size) subgroup_size: u32,
176
+ @builtin(num_subgroups) num_subgroups: u32,
177
+ @builtin(subgroup_invocation_id) sg_inv_id: u32) {
178
+
179
+ // initialize the running row max and exp sum for online softmax
180
+ for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
181
+ row_max_shmem[i] = FLOAT_MIN;
182
+ exp_sum_shmem[i] = 0.0;
183
+ }
184
+
185
+ for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { // zero the output accumulator tile
186
+ o_shmem[i] = 0.0;
187
+ }
188
+
189
+ // workgroups per head/batch
190
+ let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; // seq_len_q / Q_TILE, rounded up
191
+ let wg_per_batch = wg_per_head * params.n_heads;
192
+
193
+ let dst2_stride = HEAD_DIM_V * params.n_heads; // dst elements per query row (all heads)
194
+ let dst3_stride = dst2_stride * params.seq_len_q; // dst elements per batch
195
+
196
+ // batch index
197
+ let batch_idx = wg_id.x / wg_per_batch;
198
+ let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
199
+ let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
200
+ let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
201
+ let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
202
+ let wg_in_batch = wg_id.x % wg_per_batch;
203
+
204
+ // head index
205
+ let head_idx = wg_in_batch / wg_per_head;
206
+ let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
207
+ let k_head_idx = head_idx / params.q_per_kv; // q_per_kv consecutive Q heads map to the same K/V head
208
+ let v_head_idx = k_head_idx;
209
+ let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
210
+ let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
211
+
212
+ // starting Q row for this workgroup
213
+ let wg_in_head = wg_in_batch % wg_per_head;
214
+ let q_row_start = wg_in_head * Q_TILE;
215
+
216
+ #ifdef MASK
217
+ // mask offset: the mask is indexed per batch and per Q row with rows of seq_len_kv entries (no per-head term)
218
+ let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
219
+ #endif
220
+
221
+ // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size]
222
+ let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
223
+
224
+ let head = f32(head_idx);
225
+ let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); // per-head multiplier applied to mask values; 1.0 unless max_bias > 0
226
+
227
+ // load q tile into shared memory (converted to f16; rows at or past seq_len_q become zero)
228
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
229
+ let q_row = elem_idx / HEAD_DIM_QK;
230
+ let q_col = elem_idx % HEAD_DIM_QK;
231
+ let head_q_row = q_row_start + q_row;
232
+ let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
233
+ q_shmem[elem_idx] = f16(select(
234
+ 0.0,
235
+ Q[global_q_row_offset + q_col],
236
+ head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
237
+ }
238
+
239
+ for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { // walk over K/V in tiles of KV_TILE rows
240
+ // clear inter_shmem to ensure zero-initialized accumulators
241
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
242
+ inter_shmem[elem_idx] = 0.0;
243
+ }
244
+
245
+ // load k tile into shared memory
246
+ #if defined(KV_Q4_0)
247
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { // each thread dequantizes NQ elements per iteration
248
+ let blck_idx = elem_idx / BLOCK_SIZE;
249
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; // offset in f16 words into the block's packed weights
250
+ let k_row = blck_idx / BLOCKS_K;
251
+ let global_k_row = kv_tile + k_row;
252
+ let block_k = blck_idx % BLOCKS_K;
253
+ let row_offset = k_row * HEAD_DIM_QK;
254
+
255
+ if (global_k_row < params.seq_len_kv) { // NOTE(review): rows at or past seq_len_kv are left unwritten in kv_shmem here (the unquantized path zero-fills them) - confirm this is intended
256
+ let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
257
+ let base_idx = global_block_idx * F16_PER_BLOCK;
258
+ let d = K[base_idx]; // scale
259
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) { // two f16 words (four packed bytes) per step
260
+ let q_0 = K[base_idx + 1u + block_offset + j];
261
+ let q_1 = K[base_idx + 1u + block_offset + j + 1];
262
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
263
+ for (var k = 0u; k < 4u; k++) {
264
+ let q_byte = get_byte(q_packed, k);
265
+ let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
266
+ let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
267
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
268
+ kv_shmem[row_offset + idx] = q_lo; // low nibble is element idx of the block
269
+ kv_shmem[row_offset + idx + 16u] = q_hi; // high nibble is element idx + 16 of the block
270
+ }
271
+ }
272
+ }
273
+ }
274
+ #elif defined(KV_Q8_0)
275
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { // each thread dequantizes NQ elements per iteration
276
+ let blck_idx = elem_idx / BLOCK_SIZE;
277
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; // offset in f16 words into the block's packed weights
278
+ let k_row = blck_idx / BLOCKS_K;
279
+ let global_k_row = kv_tile + k_row;
280
+ let block_k = blck_idx % BLOCKS_K;
281
+ let row_offset = k_row * HEAD_DIM_QK;
282
+
283
+ if (global_k_row < params.seq_len_kv) { // NOTE(review): rows at or past seq_len_kv are left unwritten in kv_shmem here (the unquantized path zero-fills them) - confirm this is intended
284
+ let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
285
+ let base_idx = global_block_idx * F16_PER_BLOCK;
286
+ let d = K[base_idx]; // scale
287
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) { // two f16 words (four packed bytes) per step
288
+ let q_0 = K[base_idx + 1u + block_offset + j];
289
+ let q_1 = K[base_idx + 1u + block_offset + j + 1];
290
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
291
+ for (var k = 0u; k < 4u; k++) {
292
+ let q_byte = get_byte_i32(q_packed, k); // signed 8-bit weight
293
+ let q_val = f16(q_byte) * d;
294
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
295
+ kv_shmem[row_offset + idx] = q_val;
296
+ }
297
+ }
298
+ }
299
+ }
300
+ #elif defined(KV_DIRECT)
301
+ // Direct global loads for KV: nothing is staged in shared memory, K is read from the storage buffer in the multiply loop
302
+ #else
303
+ for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
304
+ let k_row = elem_idx / HEAD_DIM_QK;
305
+ let k_col = elem_idx % HEAD_DIM_QK;
306
+ let global_k_row = kv_tile + k_row;
307
+ let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
308
+ kv_shmem[elem_idx] = f16(select(
309
+ 0.0,
310
+ K[global_k_row_offset + k_col],
311
+ global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));
312
+ }
313
+ #endif
314
+
315
+ workgroupBarrier(); // shared-memory tiles written above must be complete before any subgroup multiplies
316
+
317
+ // accumulate q block * k block into registers across the entire KV tile
318
+ // TODO: this loop seems to be the current largest bottleneck
319
+ // this bracket exists to scope the lifetime of variables, reducing register pressure
320
+ {
321
+ #ifdef KV_DIRECT
322
+ let k_block_row = kv_tile + subgroup_id * SG_MAT_N; // NOTE(review): no seq_len_kv bounds check is visible on the direct K/V loads - presumably guaranteed by the host when KV_DIRECT is defined; confirm
323
+ var k_global_offset = k_head_offset + k_block_row * params.stride_k1;
324
+ #else
325
+ var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK;
326
+ #endif
327
+ for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { // KV blocks are distributed across subgroups
328
+ let inter_offset = kv_block * SG_MAT_N;
329
+ var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
330
+
331
+ var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
332
+
333
+ #ifdef KV_DIRECT
334
+ var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
335
+ #else
336
+ var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
337
+ #endif
338
+
339
+ var t: u32 = 1u; // chunk 0 is already in q_cur/k_cur; each step loads the next chunk, then multiplies the current one
340
+ for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) { // two head-dimension chunks per iteration
341
+ let h0 = t * SG_MAT_K;
342
+ var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
343
+ #ifdef KV_DIRECT
344
+ var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
345
+ #else
346
+ var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
347
+ #endif
348
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
349
+ q_cur = q0;
350
+ k_cur = k0;
351
+
352
+ let h1 = (t + 1u) * SG_MAT_K;
353
+ var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
354
+ #ifdef KV_DIRECT
355
+ var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
356
+ #else
357
+ var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
358
+ #endif
359
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
360
+ q_cur = q1g;
361
+ k_cur = k1g;
362
+ }
363
+
364
+ // handle odd tail: one chunk remains that the two-at-a-time loop did not load
365
+ if (t < HEAD_DIM_QK / SG_MAT_K) {
366
+ let h = t * SG_MAT_K;
367
+ var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
368
+ #ifdef KV_DIRECT
369
+ var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
370
+ #else
371
+ var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
372
+ #endif
373
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
374
+ q_cur = qn;
375
+ k_cur = kn;
376
+ }
377
+
378
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); // multiply the last chunk that was loaded
379
+
380
+ #ifdef KV_DIRECT
381
+ k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1;
382
+ #else
383
+ k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK;
384
+ #endif
385
+ subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); // write this block of scores back to shared memory
386
+ }
387
+ }
388
+
389
+
390
+ #ifdef MASK
391
+ // load mask tile into shared memory for this KV block (out-of-range entries become 0)
392
+ // TODO: optimize and skip if mask is -INF for the entire tile
393
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
394
+ let mask_row = elem_idx / KV_TILE;
395
+ let mask_col = elem_idx % KV_TILE;
396
+ let global_q_row = q_row_start + mask_row;
397
+ let global_k_col = kv_tile + mask_col;
398
+ let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
399
+ let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
400
+ mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
401
+ }
402
+ #endif
403
+
404
+ workgroupBarrier(); // scores (and mask tile) must be complete before the softmax pass reads them
405
+
406
+ // online softmax: rows of the Q tile are distributed across subgroups, lanes of a subgroup cover the KV columns
407
+ for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
408
+ let global_q_row = q_row_start + q_tile_row;
409
+ if (global_q_row >= params.seq_len_q) {
410
+ break;
411
+ }
412
+
413
+ // initialize running max for this row
414
+ var prev_max = row_max_shmem[q_tile_row];
415
+ var final_max = prev_max;
416
+ // pass 1: compute final max across the full KV tile in chunks
417
+ for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
418
+ let kv_idx = kv_offset + sg_inv_id;
419
+ let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
420
+ final_max = subgroupMax(max(final_max, softmax_term));
421
+ }
422
+
423
+ var total_exp_term: f32 = 0.0;
424
+ // pass 2: compute exp sum and write P using final_max
425
+ for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
426
+ let kv_idx = kv_offset + sg_inv_id;
427
+ let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
428
+ let cur_p = select(0.0,
429
+ exp(softmax_term - final_max),
430
+ kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); // columns past seq_len_kv or past the tile contribute 0
431
+ total_exp_term += subgroupAdd(cur_p);
432
+ if (kv_idx < KV_TILE) {
433
+ inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
434
+ }
435
+ }
436
+
437
+ let cur_exp = exp(prev_max - final_max); // rescale factor for values accumulated under the previous max
438
+
439
+ if (sg_inv_id == 0) {
440
+ row_max_shmem[q_tile_row] = final_max;
441
+ exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
442
+ }
443
+
444
+ for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { // rescale this row of the output accumulator
445
+ let idx = q_tile_row * HEAD_DIM_V + elem_idx;
446
+ o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
447
+ }
448
+ }
449
+
450
+ // load v tile into shared memory (overwrites the K tile in kv_shmem)
451
+ #if defined(KV_Q4_0)
452
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { // each thread dequantizes NQ elements per iteration
453
+ let blck_idx = elem_idx / BLOCK_SIZE;
454
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; // offset in f16 words into the block's packed weights
455
+ let v_row = blck_idx / BLOCKS_V;
456
+ let global_v_row = kv_tile + v_row;
457
+ let block_k = blck_idx % BLOCKS_V;
458
+ let row_offset = v_row * HEAD_DIM_V;
459
+
460
+ if (global_v_row < params.seq_len_kv) { // NOTE(review): rows at or past seq_len_kv are left unwritten in kv_shmem here (the unquantized path zero-fills them) - confirm this is intended
461
+ let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
462
+ let base_idx = global_block_idx * F16_PER_BLOCK;
463
+ let d = V[base_idx]; // scale
464
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) { // two f16 words (four packed bytes) per step
465
+ let q_0 = V[base_idx + 1u + block_offset + j];
466
+ let q_1 = V[base_idx + 1u + block_offset + j + 1];
467
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
468
+ for (var k = 0u; k < 4u; k++) {
469
+ let q_byte = get_byte(q_packed, k);
470
+ let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
471
+ let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
472
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
473
+ kv_shmem[row_offset + idx] = q_lo; // low nibble is element idx of the block
474
+ kv_shmem[row_offset + idx + 16u] = q_hi; // high nibble is element idx + 16 of the block
475
+ }
476
+ }
477
+ }
478
+ }
479
+ #elif defined(KV_Q8_0)
480
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { // each thread dequantizes NQ elements per iteration
481
+ let blck_idx = elem_idx / BLOCK_SIZE;
482
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; // offset in f16 words into the block's packed weights
483
+ let v_row = blck_idx / BLOCKS_V;
484
+ let global_v_row = kv_tile + v_row;
485
+ let block_k = blck_idx % BLOCKS_V;
486
+ let row_offset = v_row * HEAD_DIM_V;
487
+
488
+ if (global_v_row < params.seq_len_kv) { // NOTE(review): rows at or past seq_len_kv are left unwritten in kv_shmem here (the unquantized path zero-fills them) - confirm this is intended
489
+ let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
490
+ let base_idx = global_block_idx * F16_PER_BLOCK;
491
+ let d = V[base_idx]; // scale
492
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) { // two f16 words (four packed bytes) per step
493
+ let q_0 = V[base_idx + 1u + block_offset + j];
494
+ let q_1 = V[base_idx + 1u + block_offset + j + 1];
495
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
496
+ for (var k = 0u; k < 4u; k++) {
497
+ let q_byte = get_byte_i32(q_packed, k); // signed 8-bit weight
498
+ let q_val = f16(q_byte) * d;
499
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
500
+ kv_shmem[row_offset + idx] = q_val;
501
+ }
502
+ }
503
+ }
504
+ }
505
+ #elif defined(KV_DIRECT)
506
+ // Direct global loads for KV: nothing is staged in shared memory, V is read from the storage buffer in the P*V loop
507
+ #else
508
+ for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {
509
+ let v_row = elem_idx / HEAD_DIM_V;
510
+ let v_col = elem_idx % HEAD_DIM_V;
511
+ let global_v_row = kv_tile + v_row;
512
+ let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
513
+ kv_shmem[elem_idx] = f16(select(
514
+ 0.0,
515
+ V[global_v_row_offset + v_col],
516
+ global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));
517
+ }
518
+ #endif
519
+
520
+ workgroupBarrier(); // P, the rescaled output accumulator and the V tile must be complete before P*V
521
+
522
+ // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
523
+ // we want to compute O += P * V across the full KV tile
524
+ for (var head_dim_block = subgroup_id * SG_MAT_N; // column blocks of the output are distributed across subgroups
525
+ head_dim_block < HEAD_DIM_V;
526
+ head_dim_block += num_subgroups * SG_MAT_N) {
527
+ // load O submatrix from shared memory
528
+ var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
529
+ &o_shmem,
530
+ head_dim_block,
531
+ false,
532
+ HEAD_DIM_V
533
+ );
534
+ for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
535
+ let p_offset = kv_block * SG_MAT_N;
536
+ var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
537
+ &inter_shmem,
538
+ p_offset,
539
+ false,
540
+ KV_TILE
541
+ );
542
+
543
+ // load V submatrix from global or shared memory
544
+ #ifdef KV_DIRECT
545
+ let v_block_row = kv_tile + kv_block * SG_MAT_N;
546
+ let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
547
+ var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
548
+ &V,
549
+ v_global_offset,
550
+ false,
551
+ params.stride_v1
552
+ );
553
+ #else
554
+ let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
555
+ var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
556
+ &kv_shmem,
557
+ v_block_offset + head_dim_block,
558
+ false,
559
+ HEAD_DIM_V
560
+ );
561
+ #endif
562
+ // O += P * V
563
+ o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
564
+ }
565
+ // store O back to shared memory
566
+ subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
567
+ }
568
+ workgroupBarrier(); // finish this tile before the next iteration overwrites the shared-memory tiles
569
+ }
570
+
571
+ #ifdef SINKS
572
+ // add sinks (applied once after processing all KV tiles): the per-head sink value joins the softmax denominator only
573
+ for (var q_tile_row = subgroup_id;
574
+ q_tile_row < Q_TILE;
575
+ q_tile_row += num_subgroups) {
576
+ // no need to process rows beyond seq_len_q
577
+ let global_q_row = q_row_start + q_tile_row;
578
+ if (global_q_row >= params.seq_len_q) {
579
+ break;
580
+ }
581
+
582
+ var prev_max = row_max_shmem[q_tile_row];
583
+
584
+ // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
585
+ let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
586
+ let new_max = subgroupMax(max(prev_max, sink_val));
587
+ let max_exp = exp(prev_max - new_max); // rescale factor for the sums accumulated under prev_max
588
+ let sink_exp = exp(sink_val - new_max);
589
+
590
+ let sink_exp_sum = subgroupAdd(sink_exp);
591
+
592
+ if (sg_inv_id == 0) {
593
+ exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
594
+ }
595
+
596
+ for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
597
+ let idx = q_tile_row * HEAD_DIM_V + elem_idx;
598
+ let val = f32(o_shmem[idx]) * max_exp;
599
+ o_shmem[idx] = f16(val);
600
+ }
601
+ }
602
+ workgroupBarrier();
603
+ #endif
604
+ for (var q_tile_row = subgroup_id; // final pass: normalize each row by its exp sum and write it to dst
605
+ q_tile_row < Q_TILE;
606
+ q_tile_row += num_subgroups) {
607
+
608
+ let global_q_row = q_row_start + q_tile_row;
609
+ if (global_q_row >= params.seq_len_q) { break; }
610
+
611
+ let exp_sum = exp_sum_shmem[q_tile_row];
612
+ let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); // a row whose exp sum is 0 is written as zeros
613
+
614
+ let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride;
615
+
616
+ for (var elem_base = sg_inv_id * 4u; // each lane handles four consecutive output elements
617
+ elem_base < HEAD_DIM_V;
618
+ elem_base += subgroup_size * 4u) {
619
+
620
+ let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
621
+ let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
622
+ let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
623
+ let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
624
+
625
+ let v = vec4<f32>(
626
+ f32(o_shmem[i0]) * scale,
627
+ f32(o_shmem[i1]) * scale,
628
+ f32(o_shmem[i2]) * scale,
629
+ f32(o_shmem[i3]) * scale
630
+ );
631
+
632
+ let dst_vec_index: u32 = (row_base + elem_base) >> 2u; // dst is an array of vec4, so the scalar index is divided by 4
633
+ dst[dst_vec_index] = v;
634
+ }
635
+ }
636
+ }