local-llm-rn 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (626) hide show
  1. package/cpp/CMakeLists.txt +285 -0
  2. package/cpp/common/CMakeLists.txt +149 -0
  3. package/cpp/common/arg.cpp +3799 -0
  4. package/cpp/common/arg.h +131 -0
  5. package/cpp/common/base64.hpp +392 -0
  6. package/cpp/common/build-info.cpp.in +4 -0
  7. package/cpp/common/chat-parser-xml-toolcall.cpp +879 -0
  8. package/cpp/common/chat-parser-xml-toolcall.h +45 -0
  9. package/cpp/common/chat-parser.cpp +1649 -0
  10. package/cpp/common/chat-parser.h +133 -0
  11. package/cpp/common/chat-peg-parser.cpp +124 -0
  12. package/cpp/common/chat-peg-parser.h +105 -0
  13. package/cpp/common/chat.cpp +3355 -0
  14. package/cpp/common/chat.h +252 -0
  15. package/cpp/common/common.cpp +1824 -0
  16. package/cpp/common/common.h +930 -0
  17. package/cpp/common/console.cpp +1137 -0
  18. package/cpp/common/console.h +41 -0
  19. package/cpp/common/debug.cpp +167 -0
  20. package/cpp/common/debug.h +43 -0
  21. package/cpp/common/download.cpp +792 -0
  22. package/cpp/common/download.h +84 -0
  23. package/cpp/common/http.h +84 -0
  24. package/cpp/common/jinja/README.md +88 -0
  25. package/cpp/common/jinja/caps.cpp +285 -0
  26. package/cpp/common/jinja/caps.h +30 -0
  27. package/cpp/common/jinja/lexer.cpp +341 -0
  28. package/cpp/common/jinja/lexer.h +157 -0
  29. package/cpp/common/jinja/parser.cpp +591 -0
  30. package/cpp/common/jinja/parser.h +21 -0
  31. package/cpp/common/jinja/runtime.cpp +867 -0
  32. package/cpp/common/jinja/runtime.h +638 -0
  33. package/cpp/common/jinja/string.cpp +213 -0
  34. package/cpp/common/jinja/string.h +61 -0
  35. package/cpp/common/jinja/utils.h +149 -0
  36. package/cpp/common/jinja/value.cpp +1393 -0
  37. package/cpp/common/jinja/value.h +756 -0
  38. package/cpp/common/json-partial.cpp +324 -0
  39. package/cpp/common/json-partial.h +39 -0
  40. package/cpp/common/json-schema-to-grammar.cpp +1153 -0
  41. package/cpp/common/json-schema-to-grammar.h +43 -0
  42. package/cpp/common/llguidance.cpp +258 -0
  43. package/cpp/common/log.cpp +446 -0
  44. package/cpp/common/log.h +119 -0
  45. package/cpp/common/ngram-cache.cpp +285 -0
  46. package/cpp/common/ngram-cache.h +101 -0
  47. package/cpp/common/ngram-map.cpp +530 -0
  48. package/cpp/common/ngram-map.h +115 -0
  49. package/cpp/common/ngram-mod.cpp +60 -0
  50. package/cpp/common/ngram-mod.h +38 -0
  51. package/cpp/common/peg-parser.cpp +1712 -0
  52. package/cpp/common/peg-parser.h +459 -0
  53. package/cpp/common/preset.cpp +483 -0
  54. package/cpp/common/preset.h +83 -0
  55. package/cpp/common/regex-partial.cpp +204 -0
  56. package/cpp/common/regex-partial.h +56 -0
  57. package/cpp/common/sampling.cpp +745 -0
  58. package/cpp/common/sampling.h +119 -0
  59. package/cpp/common/speculative.cpp +1074 -0
  60. package/cpp/common/speculative.h +41 -0
  61. package/cpp/common/unicode.cpp +64 -0
  62. package/cpp/common/unicode.h +22 -0
  63. package/cpp/ggml/CMakeLists.txt +494 -0
  64. package/cpp/ggml/cmake/GitVars.cmake +22 -0
  65. package/cpp/ggml/cmake/common.cmake +50 -0
  66. package/cpp/ggml/cmake/ggml-config.cmake.in +191 -0
  67. package/cpp/ggml/include/ggml-alloc.h +85 -0
  68. package/cpp/ggml/include/ggml-backend.h +373 -0
  69. package/cpp/ggml/include/ggml-blas.h +25 -0
  70. package/cpp/ggml/include/ggml-cann.h +123 -0
  71. package/cpp/ggml/include/ggml-cpp.h +39 -0
  72. package/cpp/ggml/include/ggml-cpu.h +151 -0
  73. package/cpp/ggml/include/ggml-cuda.h +47 -0
  74. package/cpp/ggml/include/ggml-hexagon.h +19 -0
  75. package/cpp/ggml/include/ggml-metal.h +61 -0
  76. package/cpp/ggml/include/ggml-opencl.h +26 -0
  77. package/cpp/ggml/include/ggml-opt.h +256 -0
  78. package/cpp/ggml/include/ggml-rpc.h +30 -0
  79. package/cpp/ggml/include/ggml-sycl.h +49 -0
  80. package/cpp/ggml/include/ggml-virtgpu.h +14 -0
  81. package/cpp/ggml/include/ggml-vulkan.h +29 -0
  82. package/cpp/ggml/include/ggml-webgpu.h +19 -0
  83. package/cpp/ggml/include/ggml-zdnn.h +17 -0
  84. package/cpp/ggml/include/ggml-zendnn.h +22 -0
  85. package/cpp/ggml/include/ggml.h +2753 -0
  86. package/cpp/ggml/include/gguf.h +204 -0
  87. package/cpp/ggml/src/CMakeLists.txt +492 -0
  88. package/cpp/ggml/src/ggml-alloc.c +1244 -0
  89. package/cpp/ggml/src/ggml-backend-dl.cpp +48 -0
  90. package/cpp/ggml/src/ggml-backend-dl.h +45 -0
  91. package/cpp/ggml/src/ggml-backend-impl.h +255 -0
  92. package/cpp/ggml/src/ggml-backend-reg.cpp +566 -0
  93. package/cpp/ggml/src/ggml-backend.cpp +2270 -0
  94. package/cpp/ggml/src/ggml-blas/CMakeLists.txt +101 -0
  95. package/cpp/ggml/src/ggml-blas/ggml-blas.cpp +518 -0
  96. package/cpp/ggml/src/ggml-common.h +1878 -0
  97. package/cpp/ggml/src/ggml-cpu/CMakeLists.txt +691 -0
  98. package/cpp/ggml/src/ggml-cpu/amx/amx.cpp +247 -0
  99. package/cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  100. package/cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  101. package/cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2512 -0
  102. package/cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  103. package/cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +98 -0
  104. package/cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4052 -0
  105. package/cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +4935 -0
  106. package/cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2159 -0
  107. package/cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  108. package/cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  109. package/cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  110. package/cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2726 -0
  111. package/cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  112. package/cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  113. package/cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  114. package/cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  115. package/cpp/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  116. package/cpp/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  117. package/cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  118. package/cpp/ggml/src/ggml-cpu/arch-fallback.h +313 -0
  119. package/cpp/ggml/src/ggml-cpu/binary-ops.cpp +154 -0
  120. package/cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  121. package/cpp/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  122. package/cpp/ggml/src/ggml-cpu/common.h +95 -0
  123. package/cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +529 -0
  124. package/cpp/ggml/src/ggml-cpu/ggml-cpu.c +3734 -0
  125. package/cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +701 -0
  126. package/cpp/ggml/src/ggml-cpu/hbm.cpp +55 -0
  127. package/cpp/ggml/src/ggml-cpu/hbm.h +8 -0
  128. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +938 -0
  129. package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +90 -0
  130. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +798 -0
  131. package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  132. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +4033 -0
  133. package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +25 -0
  134. package/cpp/ggml/src/ggml-cpu/ops.cpp +10978 -0
  135. package/cpp/ggml/src/ggml-cpu/ops.h +116 -0
  136. package/cpp/ggml/src/ggml-cpu/quants.c +1193 -0
  137. package/cpp/ggml/src/ggml-cpu/quants.h +97 -0
  138. package/cpp/ggml/src/ggml-cpu/repack.cpp +3316 -0
  139. package/cpp/ggml/src/ggml-cpu/repack.h +173 -0
  140. package/cpp/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  141. package/cpp/ggml/src/ggml-cpu/simd-mappings.h +1279 -0
  142. package/cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  143. package/cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  144. package/cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  145. package/cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  146. package/cpp/ggml/src/ggml-cpu/traits.cpp +36 -0
  147. package/cpp/ggml/src/ggml-cpu/traits.h +38 -0
  148. package/cpp/ggml/src/ggml-cpu/unary-ops.cpp +337 -0
  149. package/cpp/ggml/src/ggml-cpu/unary-ops.h +35 -0
  150. package/cpp/ggml/src/ggml-cpu/vec.cpp +629 -0
  151. package/cpp/ggml/src/ggml-cpu/vec.h +1585 -0
  152. package/cpp/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  153. package/cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3232 -0
  154. package/cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -0
  155. package/cpp/ggml/src/ggml-hexagon/htp/act-ops.c +815 -0
  156. package/cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  157. package/cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +827 -0
  158. package/cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  159. package/cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c +251 -0
  160. package/cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +666 -0
  161. package/cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c +111 -0
  162. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  163. package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  164. package/cpp/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  165. package/cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  166. package/cpp/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  167. package/cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  168. package/cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +154 -0
  169. package/cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +65 -0
  170. package/cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  171. package/cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h +470 -0
  172. package/cpp/ggml/src/ggml-hexagon/htp/hvx-base.h +173 -0
  173. package/cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  174. package/cpp/ggml/src/ggml-hexagon/htp/hvx-div.h +116 -0
  175. package/cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  176. package/cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  177. package/cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  178. package/cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h +176 -0
  179. package/cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h +266 -0
  180. package/cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  181. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  182. package/cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  183. package/cpp/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  184. package/cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -0
  185. package/cpp/ggml/src/ggml-hexagon/htp/main.c +1150 -0
  186. package/cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2595 -0
  187. package/cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +498 -0
  188. package/cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c +167 -0
  189. package/cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +421 -0
  190. package/cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +130 -0
  191. package/cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +384 -0
  192. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  193. package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  194. package/cpp/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  195. package/cpp/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  196. package/cpp/ggml/src/ggml-hexagon/libdl.h +79 -0
  197. package/cpp/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  198. package/cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
  199. package/cpp/ggml/src/ggml-impl.h +724 -0
  200. package/cpp/ggml/src/ggml-metal/CMakeLists.txt +124 -0
  201. package/cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +457 -0
  202. package/cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  203. package/cpp/ggml/src/ggml-metal/ggml-metal-context.h +41 -0
  204. package/cpp/ggml/src/ggml-metal/ggml-metal-context.m +702 -0
  205. package/cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1890 -0
  206. package/cpp/ggml/src/ggml-metal/ggml-metal-device.h +290 -0
  207. package/cpp/ggml/src/ggml-metal/ggml-metal-device.m +1749 -0
  208. package/cpp/ggml/src/ggml-metal/ggml-metal-impl.h +1054 -0
  209. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +4370 -0
  210. package/cpp/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  211. package/cpp/ggml/src/ggml-metal/ggml-metal.cpp +937 -0
  212. package/cpp/ggml/src/ggml-metal/ggml-metal.metal +9819 -0
  213. package/cpp/ggml/src/ggml-musa/CMakeLists.txt +125 -0
  214. package/cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  215. package/cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  216. package/cpp/ggml/src/ggml-opencl/CMakeLists.txt +150 -0
  217. package/cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +11553 -0
  218. package/cpp/ggml/src/ggml-opencl/kernels/add.cl +190 -0
  219. package/cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  220. package/cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  221. package/cpp/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  222. package/cpp/ggml/src/ggml-opencl/kernels/concat.cl +51 -0
  223. package/cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  224. package/cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  225. package/cpp/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  226. package/cpp/ggml/src/ggml-opencl/kernels/cvt.cl +417 -0
  227. package/cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  228. package/cpp/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  229. package/cpp/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  230. package/cpp/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  231. package/cpp/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  232. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  233. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  234. package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  235. package/cpp/ggml/src/ggml-opencl/kernels/gelu.cl +89 -0
  236. package/cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  237. package/cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  238. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  239. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  240. package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  241. package/cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +187 -0
  242. package/cpp/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  243. package/cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  244. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  245. package/cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  246. package/cpp/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  247. package/cpp/ggml/src/ggml-opencl/kernels/mul.cl +152 -0
  248. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  249. package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  250. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  251. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  252. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  253. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  254. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  255. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  256. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  257. package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  258. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  259. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  260. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  261. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  262. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  263. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  264. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  265. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  266. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  267. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  268. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  269. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  270. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  271. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  272. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  273. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  274. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  275. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  276. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  277. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  278. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl +194 -0
  279. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  280. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  281. package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  282. package/cpp/ggml/src/ggml-opencl/kernels/norm.cl +161 -0
  283. package/cpp/ggml/src/ggml-opencl/kernels/pad.cl +39 -0
  284. package/cpp/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  285. package/cpp/ggml/src/ggml-opencl/kernels/repeat.cl +38 -0
  286. package/cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +190 -0
  287. package/cpp/ggml/src/ggml-opencl/kernels/rope.cl +747 -0
  288. package/cpp/ggml/src/ggml-opencl/kernels/scale.cl +27 -0
  289. package/cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  290. package/cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  291. package/cpp/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  292. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +108 -0
  293. package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +108 -0
  294. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +107 -0
  295. package/cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +107 -0
  296. package/cpp/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  297. package/cpp/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  298. package/cpp/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  299. package/cpp/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  300. package/cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  301. package/cpp/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  302. package/cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +140 -0
  303. package/cpp/ggml/src/ggml-opencl/kernels/tanh.cl +109 -0
  304. package/cpp/ggml/src/ggml-opencl/kernels/transpose.cl +117 -0
  305. package/cpp/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  306. package/cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  307. package/cpp/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  308. package/cpp/ggml/src/ggml-opt.cpp +1093 -0
  309. package/cpp/ggml/src/ggml-quants.c +5325 -0
  310. package/cpp/ggml/src/ggml-quants.h +106 -0
  311. package/cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  312. package/cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2118 -0
  313. package/cpp/ggml/src/ggml-threading.cpp +12 -0
  314. package/cpp/ggml/src/ggml-threading.h +14 -0
  315. package/cpp/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  316. package/cpp/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  317. package/cpp/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  318. package/cpp/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  319. package/cpp/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  320. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  321. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  322. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  323. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  324. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  325. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  326. package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  327. package/cpp/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  328. package/cpp/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  329. package/cpp/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  330. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  331. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  332. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  333. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  334. package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  335. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  336. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  337. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  338. package/cpp/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  339. package/cpp/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  340. package/cpp/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  341. package/cpp/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  342. package/cpp/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  343. package/cpp/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  344. package/cpp/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  345. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  346. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  347. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  348. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  349. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  350. package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  351. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  352. package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  353. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  354. package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  355. package/cpp/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  356. package/cpp/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  357. package/cpp/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  358. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1231 -0
  359. package/cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3150 -0
  360. package/cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  361. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  362. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  363. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  364. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +107 -0
  365. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +923 -0
  366. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  367. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  368. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +182 -0
  369. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  370. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +668 -0
  371. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  372. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  373. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +713 -0
  374. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +103 -0
  375. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +138 -0
  376. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +188 -0
  377. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +194 -0
  378. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  379. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  380. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  381. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  382. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +109 -0
  383. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  384. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  385. package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  386. package/cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  387. package/cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
  388. package/cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +633 -0
  389. package/cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  390. package/cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  391. package/cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
  392. package/cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
  393. package/cpp/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  394. package/cpp/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  395. package/cpp/ggml/src/ggml.c +7669 -0
  396. package/cpp/ggml/src/ggml.cpp +26 -0
  397. package/cpp/ggml/src/gguf.cpp +1699 -0
  398. package/cpp/include/llama-cpp.h +32 -0
  399. package/cpp/include/llama.h +1568 -0
  400. package/cpp/mtmd/CMakeLists.txt +98 -0
  401. package/cpp/mtmd/README.md +63 -0
  402. package/cpp/mtmd/clip-graph.h +117 -0
  403. package/cpp/mtmd/clip-impl.h +586 -0
  404. package/cpp/mtmd/clip-model.h +390 -0
  405. package/cpp/mtmd/clip.cpp +4154 -0
  406. package/cpp/mtmd/clip.h +121 -0
  407. package/cpp/mtmd/deprecation-warning.cpp +22 -0
  408. package/cpp/mtmd/legacy-models/convert_image_encoder_to_gguf.py +412 -0
  409. package/cpp/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py +280 -0
  410. package/cpp/mtmd/legacy-models/glmedge-surgery.py +33 -0
  411. package/cpp/mtmd/legacy-models/llava_surgery.py +38 -0
  412. package/cpp/mtmd/legacy-models/llava_surgery_v2.py +180 -0
  413. package/cpp/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +892 -0
  414. package/cpp/mtmd/legacy-models/minicpmv-surgery.py +47 -0
  415. package/cpp/mtmd/models/cogvlm.cpp +98 -0
  416. package/cpp/mtmd/models/conformer.cpp +216 -0
  417. package/cpp/mtmd/models/glm4v.cpp +122 -0
  418. package/cpp/mtmd/models/internvl.cpp +69 -0
  419. package/cpp/mtmd/models/kimik25.cpp +101 -0
  420. package/cpp/mtmd/models/kimivl.cpp +63 -0
  421. package/cpp/mtmd/models/llama4.cpp +96 -0
  422. package/cpp/mtmd/models/llava.cpp +374 -0
  423. package/cpp/mtmd/models/minicpmv.cpp +114 -0
  424. package/cpp/mtmd/models/mobilenetv5.cpp +451 -0
  425. package/cpp/mtmd/models/models.h +128 -0
  426. package/cpp/mtmd/models/nemotron-v2-vl.cpp +35 -0
  427. package/cpp/mtmd/models/paddleocr.cpp +52 -0
  428. package/cpp/mtmd/models/pixtral.cpp +86 -0
  429. package/cpp/mtmd/models/qwen2vl.cpp +183 -0
  430. package/cpp/mtmd/models/qwen3vl.cpp +193 -0
  431. package/cpp/mtmd/models/siglip.cpp +86 -0
  432. package/cpp/mtmd/models/whisper-enc.cpp +115 -0
  433. package/cpp/mtmd/models/youtuvl.cpp +179 -0
  434. package/cpp/mtmd/mtmd-audio.cpp +730 -0
  435. package/cpp/mtmd/mtmd-audio.h +113 -0
  436. package/cpp/mtmd/mtmd-cli.cpp +437 -0
  437. package/cpp/mtmd/mtmd-helper.cpp +521 -0
  438. package/cpp/mtmd/mtmd-helper.h +96 -0
  439. package/cpp/mtmd/mtmd.cpp +1156 -0
  440. package/cpp/mtmd/mtmd.h +319 -0
  441. package/cpp/mtmd/requirements.txt +5 -0
  442. package/cpp/mtmd/test-1.jpeg +0 -0
  443. package/cpp/mtmd/test-2.mp3 +0 -0
  444. package/cpp/mtmd/tests.sh +192 -0
  445. package/cpp/src/CMakeLists.txt +169 -0
  446. package/cpp/src/llama-adapter.cpp +488 -0
  447. package/cpp/src/llama-adapter.h +89 -0
  448. package/cpp/src/llama-arch.cpp +2855 -0
  449. package/cpp/src/llama-arch.h +619 -0
  450. package/cpp/src/llama-batch.cpp +917 -0
  451. package/cpp/src/llama-batch.h +173 -0
  452. package/cpp/src/llama-chat.cpp +896 -0
  453. package/cpp/src/llama-chat.h +71 -0
  454. package/cpp/src/llama-context.cpp +3512 -0
  455. package/cpp/src/llama-context.h +359 -0
  456. package/cpp/src/llama-cparams.cpp +5 -0
  457. package/cpp/src/llama-cparams.h +44 -0
  458. package/cpp/src/llama-grammar.cpp +1464 -0
  459. package/cpp/src/llama-grammar.h +194 -0
  460. package/cpp/src/llama-graph.cpp +2685 -0
  461. package/cpp/src/llama-graph.h +1026 -0
  462. package/cpp/src/llama-hparams.cpp +234 -0
  463. package/cpp/src/llama-hparams.h +339 -0
  464. package/cpp/src/llama-impl.cpp +171 -0
  465. package/cpp/src/llama-impl.h +73 -0
  466. package/cpp/src/llama-io.cpp +15 -0
  467. package/cpp/src/llama-io.h +35 -0
  468. package/cpp/src/llama-kv-cache-iswa.cpp +330 -0
  469. package/cpp/src/llama-kv-cache-iswa.h +137 -0
  470. package/cpp/src/llama-kv-cache.cpp +2271 -0
  471. package/cpp/src/llama-kv-cache.h +388 -0
  472. package/cpp/src/llama-kv-cells.h +533 -0
  473. package/cpp/src/llama-memory-hybrid-iswa.cpp +275 -0
  474. package/cpp/src/llama-memory-hybrid-iswa.h +140 -0
  475. package/cpp/src/llama-memory-hybrid.cpp +268 -0
  476. package/cpp/src/llama-memory-hybrid.h +139 -0
  477. package/cpp/src/llama-memory-recurrent.cpp +1165 -0
  478. package/cpp/src/llama-memory-recurrent.h +182 -0
  479. package/cpp/src/llama-memory.cpp +59 -0
  480. package/cpp/src/llama-memory.h +122 -0
  481. package/cpp/src/llama-mmap.cpp +785 -0
  482. package/cpp/src/llama-mmap.h +92 -0
  483. package/cpp/src/llama-model-loader.cpp +1414 -0
  484. package/cpp/src/llama-model-loader.h +203 -0
  485. package/cpp/src/llama-model-saver.cpp +286 -0
  486. package/cpp/src/llama-model-saver.h +37 -0
  487. package/cpp/src/llama-model.cpp +9253 -0
  488. package/cpp/src/llama-model.h +576 -0
  489. package/cpp/src/llama-quant.cpp +1119 -0
  490. package/cpp/src/llama-quant.h +1 -0
  491. package/cpp/src/llama-sampler.cpp +3885 -0
  492. package/cpp/src/llama-sampler.h +42 -0
  493. package/cpp/src/llama-vocab.cpp +3970 -0
  494. package/cpp/src/llama-vocab.h +187 -0
  495. package/cpp/src/llama.cpp +1313 -0
  496. package/cpp/src/models/afmoe.cpp +191 -0
  497. package/cpp/src/models/apertus.cpp +125 -0
  498. package/cpp/src/models/arcee.cpp +135 -0
  499. package/cpp/src/models/arctic.cpp +138 -0
  500. package/cpp/src/models/arwkv7.cpp +86 -0
  501. package/cpp/src/models/baichuan.cpp +122 -0
  502. package/cpp/src/models/bailingmoe.cpp +144 -0
  503. package/cpp/src/models/bailingmoe2.cpp +135 -0
  504. package/cpp/src/models/bert.cpp +178 -0
  505. package/cpp/src/models/bitnet.cpp +160 -0
  506. package/cpp/src/models/bloom.cpp +101 -0
  507. package/cpp/src/models/chameleon.cpp +178 -0
  508. package/cpp/src/models/chatglm.cpp +132 -0
  509. package/cpp/src/models/codeshell.cpp +111 -0
  510. package/cpp/src/models/cogvlm.cpp +102 -0
  511. package/cpp/src/models/cohere2-iswa.cpp +134 -0
  512. package/cpp/src/models/command-r.cpp +122 -0
  513. package/cpp/src/models/dbrx.cpp +123 -0
  514. package/cpp/src/models/deci.cpp +135 -0
  515. package/cpp/src/models/deepseek.cpp +144 -0
  516. package/cpp/src/models/deepseek2.cpp +262 -0
  517. package/cpp/src/models/delta-net-base.cpp +376 -0
  518. package/cpp/src/models/dots1.cpp +134 -0
  519. package/cpp/src/models/dream.cpp +105 -0
  520. package/cpp/src/models/ernie4-5-moe.cpp +150 -0
  521. package/cpp/src/models/ernie4-5.cpp +110 -0
  522. package/cpp/src/models/eurobert.cpp +97 -0
  523. package/cpp/src/models/exaone-moe.cpp +146 -0
  524. package/cpp/src/models/exaone.cpp +114 -0
  525. package/cpp/src/models/exaone4.cpp +123 -0
  526. package/cpp/src/models/falcon-h1.cpp +111 -0
  527. package/cpp/src/models/falcon.cpp +120 -0
  528. package/cpp/src/models/gemma-embedding.cpp +116 -0
  529. package/cpp/src/models/gemma.cpp +112 -0
  530. package/cpp/src/models/gemma2-iswa.cpp +128 -0
  531. package/cpp/src/models/gemma3.cpp +155 -0
  532. package/cpp/src/models/gemma3n-iswa.cpp +384 -0
  533. package/cpp/src/models/glm4-moe.cpp +170 -0
  534. package/cpp/src/models/glm4.cpp +157 -0
  535. package/cpp/src/models/gpt2.cpp +105 -0
  536. package/cpp/src/models/gptneox.cpp +144 -0
  537. package/cpp/src/models/granite-hybrid.cpp +196 -0
  538. package/cpp/src/models/granite.cpp +211 -0
  539. package/cpp/src/models/grok.cpp +159 -0
  540. package/cpp/src/models/grovemoe.cpp +141 -0
  541. package/cpp/src/models/hunyuan-dense.cpp +132 -0
  542. package/cpp/src/models/hunyuan-moe.cpp +154 -0
  543. package/cpp/src/models/internlm2.cpp +120 -0
  544. package/cpp/src/models/jais.cpp +86 -0
  545. package/cpp/src/models/jais2.cpp +123 -0
  546. package/cpp/src/models/jamba.cpp +106 -0
  547. package/cpp/src/models/kimi-linear.cpp +392 -0
  548. package/cpp/src/models/lfm2.cpp +190 -0
  549. package/cpp/src/models/llada-moe.cpp +122 -0
  550. package/cpp/src/models/llada.cpp +99 -0
  551. package/cpp/src/models/llama-iswa.cpp +178 -0
  552. package/cpp/src/models/llama.cpp +168 -0
  553. package/cpp/src/models/maincoder.cpp +117 -0
  554. package/cpp/src/models/mamba-base.cpp +285 -0
  555. package/cpp/src/models/mamba.cpp +54 -0
  556. package/cpp/src/models/mimo2-iswa.cpp +123 -0
  557. package/cpp/src/models/minicpm3.cpp +200 -0
  558. package/cpp/src/models/minimax-m2.cpp +124 -0
  559. package/cpp/src/models/mistral3.cpp +160 -0
  560. package/cpp/src/models/models.h +684 -0
  561. package/cpp/src/models/modern-bert.cpp +109 -0
  562. package/cpp/src/models/mpt.cpp +126 -0
  563. package/cpp/src/models/nemotron-h.cpp +148 -0
  564. package/cpp/src/models/nemotron.cpp +122 -0
  565. package/cpp/src/models/neo-bert.cpp +104 -0
  566. package/cpp/src/models/olmo.cpp +121 -0
  567. package/cpp/src/models/olmo2.cpp +150 -0
  568. package/cpp/src/models/olmoe.cpp +124 -0
  569. package/cpp/src/models/openai-moe-iswa.cpp +127 -0
  570. package/cpp/src/models/openelm.cpp +124 -0
  571. package/cpp/src/models/orion.cpp +123 -0
  572. package/cpp/src/models/paddleocr.cpp +122 -0
  573. package/cpp/src/models/pangu-embedded.cpp +121 -0
  574. package/cpp/src/models/phi2.cpp +121 -0
  575. package/cpp/src/models/phi3.cpp +152 -0
  576. package/cpp/src/models/plamo.cpp +110 -0
  577. package/cpp/src/models/plamo2.cpp +318 -0
  578. package/cpp/src/models/plamo3.cpp +128 -0
  579. package/cpp/src/models/plm.cpp +169 -0
  580. package/cpp/src/models/qwen.cpp +108 -0
  581. package/cpp/src/models/qwen2.cpp +126 -0
  582. package/cpp/src/models/qwen2moe.cpp +151 -0
  583. package/cpp/src/models/qwen2vl.cpp +117 -0
  584. package/cpp/src/models/qwen3.cpp +117 -0
  585. package/cpp/src/models/qwen35.cpp +386 -0
  586. package/cpp/src/models/qwen35moe.cpp +420 -0
  587. package/cpp/src/models/qwen3moe.cpp +124 -0
  588. package/cpp/src/models/qwen3next.cpp +525 -0
  589. package/cpp/src/models/qwen3vl-moe.cpp +140 -0
  590. package/cpp/src/models/qwen3vl.cpp +132 -0
  591. package/cpp/src/models/refact.cpp +94 -0
  592. package/cpp/src/models/rnd1.cpp +126 -0
  593. package/cpp/src/models/rwkv6-base.cpp +164 -0
  594. package/cpp/src/models/rwkv6.cpp +94 -0
  595. package/cpp/src/models/rwkv6qwen2.cpp +86 -0
  596. package/cpp/src/models/rwkv7-base.cpp +137 -0
  597. package/cpp/src/models/rwkv7.cpp +90 -0
  598. package/cpp/src/models/seed-oss.cpp +124 -0
  599. package/cpp/src/models/smallthinker.cpp +126 -0
  600. package/cpp/src/models/smollm3.cpp +128 -0
  601. package/cpp/src/models/stablelm.cpp +146 -0
  602. package/cpp/src/models/starcoder.cpp +100 -0
  603. package/cpp/src/models/starcoder2.cpp +121 -0
  604. package/cpp/src/models/step35-iswa.cpp +168 -0
  605. package/cpp/src/models/t5-dec.cpp +166 -0
  606. package/cpp/src/models/t5-enc.cpp +96 -0
  607. package/cpp/src/models/wavtokenizer-dec.cpp +149 -0
  608. package/cpp/src/models/xverse.cpp +108 -0
  609. package/cpp/src/unicode-data.cpp +7034 -0
  610. package/cpp/src/unicode-data.h +20 -0
  611. package/cpp/src/unicode.cpp +1103 -0
  612. package/cpp/src/unicode.h +111 -0
  613. package/cpp/vendor/nlohmann/json.hpp +25526 -0
  614. package/cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  615. package/cpp/vendor/stb/stb_image.h +7988 -0
  616. package/ios/LocalLLM-Bridging-Header.h +2 -0
  617. package/ios/LocalLLM.h +5 -0
  618. package/ios/LocalLLM.mm +1267 -0
  619. package/local-llm-rn.podspec +60 -0
  620. package/package.json +35 -0
  621. package/src/NativeLocalLLM.ts +73 -0
  622. package/src/device.ts +50 -0
  623. package/src/download-adapter.ts +17 -0
  624. package/src/index.ts +21 -0
  625. package/src/native-bridge.ts +142 -0
  626. package/src/rn-downloader.ts +37 -0
@@ -0,0 +1,103 @@
1
+ #ifdef VEC
2
+ #define VEC_SIZE 4
3
+ #define SHMEM_TYPE vec4<f16>
4
+ #define DST_TYPE vec4<f32>
5
+ #define SRC0_TYPE vec4<SRC0_INNER_TYPE>
6
+ #define SRC1_TYPE vec4<SRC1_INNER_TYPE>
7
+
8
+ fn store_shmem(val: vec4<f16>, idx: u32) {
9
+ shmem[idx] = val.x;
10
+ shmem[idx + 1] = val.y;
11
+ shmem[idx + 2] = val.z;
12
+ shmem[idx + 3] = val.w;
13
+ }
14
+ #endif
15
+
16
+ #ifdef SCALAR
17
+ #define VEC_SIZE 1
18
+ #define SHMEM_TYPE f16
19
+ #define DST_TYPE f32
20
+ #define SRC0_TYPE SRC0_INNER_TYPE
21
+ #define SRC1_TYPE SRC1_INNER_TYPE
22
+
23
+ fn store_shmem(val: f16, idx: u32) {
24
+ shmem[idx] = val;
25
+ }
26
+ #endif
27
+
28
+ #ifdef INIT_SRC0_SHMEM_FLOAT
29
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
30
+ for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
31
+ let tile_m = elem_idx / TILE_K;
32
+ let tile_k = elem_idx % TILE_K;
33
+ let global_m = offset_m + tile_m;
34
+ let global_k = k_outer + tile_k;
35
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
36
+ let src0_val = select( // taking a slight performance hit to avoid oob
37
+ SRC0_TYPE(0.0),
38
+ src0[src0_idx/VEC_SIZE],
39
+ global_m < params.m && global_k < params.k);
40
+ store_shmem(SHMEM_TYPE(src0_val), elem_idx);
41
+ }
42
+ }
43
+ #endif
44
+
45
+ #ifdef INIT_SRC1_SHMEM_FLOAT
46
+ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
47
+ for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
48
+ let tile_n = elem_idx / TILE_K;
49
+ let tile_k = elem_idx % TILE_K;
50
+ let global_n = offset_n + tile_n;
51
+ let global_k = k_outer + tile_k;
52
+ let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
53
+ let src1_val = select(
54
+ SRC1_TYPE(0.0),
55
+ src1[src1_idx/VEC_SIZE],
56
+ global_n < params.n && global_k < params.k);
57
+ store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
58
+ }
59
+ }
60
+ #endif
61
+
62
+ #ifdef INIT_SRC0_SHMEM_Q4_0
63
+ const BLOCK_SIZE = 32u;
64
+ // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
65
+ override BLOCKS_K = TILE_K/BLOCK_SIZE;
66
+ const NQ = 16u;
67
+ const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
68
+ const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
69
+ const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
70
+
71
+ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
72
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
73
+ let blck_idx = i / BLOCK_SIZE;
74
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
75
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
76
+
77
+ let tile_m = blck_idx / BLOCKS_K;
78
+ let global_m = offset_m + tile_m;
79
+ let block_k = blck_idx % BLOCKS_K;
80
+ let global_k = k_outer / BLOCK_SIZE + block_k;
81
+
82
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
83
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
84
+ let scale_idx = src0_idx * F16_PER_BLOCK;
85
+ let d = src0[scale_idx];
86
+
87
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
88
+ let q_0 = src0[scale_idx + 1u + block_offset + j];
89
+ let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
90
+
91
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
92
+ for (var k = 0u; k < 4u; k++) {
93
+ let q_byte = get_byte(q_packed, k);
94
+ let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
95
+ let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
96
+ shmem[shmem_idx + j * 2 + k] = q_lo;
97
+ shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
98
+ }
99
+ }
100
+ }
101
+ }
102
+ }
103
+ #endif
@@ -0,0 +1,138 @@
1
+ enable f16;
2
+
3
+ #include "common_decls.tmpl"
4
+ #include "mul_mat_decls.tmpl"
5
+
6
+ #ifdef VEC
7
+ fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
8
+ return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
9
+ }
10
+ #endif
11
+
12
+ #ifdef SCALAR
13
+ fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
14
+ return f32(acc[tm][tn]);
15
+ }
16
+ #endif
17
+
18
+ struct MulMatParams {
19
+ offset_src0: u32,
20
+ offset_src1: u32,
21
+ offset_dst: u32,
22
+ m: u32,
23
+ n: u32,
24
+ k: u32,
25
+ stride_01: u32,
26
+ stride_11: u32,
27
+ stride_02: u32,
28
+ stride_12: u32,
29
+ stride_03: u32,
30
+ stride_13: u32,
31
+ bs02: u32,
32
+ bs03: u32,
33
+ broadcast2: u32,
34
+ broadcast3: u32
35
+ };
36
+
37
+ @group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
38
+ @group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
39
+ @group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
40
+
41
+ @group(0) @binding(3) var<uniform> params: MulMatParams;
42
+
43
+ fn get_local_n(thread_id: u32) -> u32 {
44
+ return thread_id / WORKGROUP_SIZE_M;
45
+ }
46
+ fn get_local_m(thread_id: u32) -> u32 {
47
+ return thread_id % WORKGROUP_SIZE_M;
48
+ }
49
+
50
+ const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
51
+ const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
52
+ const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
53
+ var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
54
+
55
+ @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
56
+ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
57
+ @builtin(local_invocation_id) local_id: vec3<u32>) {
58
+
59
+ let thread_id = local_id.x;
60
+ let local_m = get_local_m(thread_id);
61
+ let local_n = get_local_n(thread_id);
62
+
63
+ let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N);
64
+ let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
65
+ let wg_per_matrix = wg_m_count * wg_n_count;
66
+
67
+ let batch_idx = wg_id.x / wg_per_matrix;
68
+
69
+ let wg_in_batch = wg_id.x % wg_per_matrix;
70
+ let wg_m = wg_in_batch % wg_m_count;
71
+ let wg_n = wg_in_batch / wg_m_count;
72
+
73
+ let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M;
74
+ let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N;
75
+
76
+ let dst2_stride = params.m * params.n;
77
+ let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
78
+
79
+ let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
80
+ let src03_idx = dst3_idx / params.broadcast3;
81
+ let src13_idx = dst3_idx;
82
+ let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
83
+ let src02_idx = dst2_idx / params.broadcast2;
84
+ let src12_idx = dst2_idx;
85
+
86
+ let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
87
+ let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
88
+
89
+ let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
90
+ let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
91
+
92
+ var acc: array<array<f16, TILE_N>, TILE_M>;
93
+
94
+ for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
95
+
96
+ // see mul_mat_decls.tmpl
97
+ init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
98
+ init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
99
+
100
+ workgroupBarrier();
101
+
102
+ let k_end = min(TILE_K, params.k - k_outer);
103
+
104
+ for (var k_inner = 0u; k_inner < k_end; k_inner++) {
105
+ var src0_tile: array<f16, TILE_M>;
106
+ for (var tm = 0u; tm < TILE_M; tm++) {
107
+ let src0_m = local_m * TILE_M + tm;
108
+ let src0_idx = k_inner + src0_m * TILE_K;
109
+ src0_tile[tm] = shmem[src0_idx];
110
+ }
111
+ for (var tn = 0u; tn < TILE_N; tn++) {
112
+ let src1_n = local_n * TILE_N + tn;
113
+ let src1_idx = src1_n * TILE_K + k_inner;
114
+ let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
115
+ for (var tm = 0u; tm < TILE_M; tm++) {
116
+ acc[tm][tn] += src0_tile[tm] * src1_val;
117
+ }
118
+ }
119
+ }
120
+
121
+ workgroupBarrier();
122
+ }
123
+
124
+ let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
125
+
126
+ for (var tn = 0u; tn < TILE_N; tn++) {
127
+ let global_col = output_col_base + tn;
128
+ if (global_col < params.n) {
129
+ for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
130
+ let global_row = output_row_base + tm;
131
+ if (global_row < params.m) {
132
+ let dst_idx = dst_batch_offset + global_col * params.m + global_row;
133
+ dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
134
+ }
135
+ }
136
+ }
137
+ }
138
+ }
@@ -0,0 +1,188 @@
1
+ diagnostic(off, chromium.subgroup_matrix_uniformity);
2
+ enable f16;
3
+ enable subgroups;
4
+ enable chromium_experimental_subgroup_matrix;
5
+
6
+ #include "common_decls.tmpl"
7
+ #include "mul_mat_decls.tmpl"
8
+
9
+ #ifdef VEC
10
+ fn store_dst(shmem_idx: u32, dst_idx: u32) {
11
+ dst[dst_idx] = vec4<f32>(
12
+ f32(shmem[shmem_idx]),
13
+ f32(shmem[shmem_idx + 1]),
14
+ f32(shmem[shmem_idx + 2]),
15
+ f32(shmem[shmem_idx + 3])
16
+ );
17
+ }
18
+ #endif
19
+
20
+ #ifdef SCALAR
21
+ fn store_dst(shmem_idx: u32, dst_idx: u32) {
22
+ dst[dst_idx] = f32(shmem[shmem_idx]);
23
+ }
24
+ #endif
25
+
26
+ struct MulMatParams {
27
+ offset_src0: u32,
28
+ offset_src1: u32,
29
+ offset_dst: u32,
30
+ m: u32,
31
+ n: u32,
32
+ k: u32,
33
+ stride_01: u32,
34
+ stride_11: u32,
35
+ stride_02: u32,
36
+ stride_12: u32,
37
+ stride_03: u32,
38
+ stride_13: u32,
39
+ bs02: u32,
40
+ bs03: u32,
41
+ broadcast2: u32,
42
+ broadcast3: u32
43
+ };
44
+
45
+ // SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
46
+ @group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
47
+ @group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
48
+ @group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
49
+
50
+ @group(0) @binding(3) var<uniform> params: MulMatParams;
51
+
52
+ const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
53
+ const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
54
+
55
+ // For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
56
+ // runtime subgroup size is smaller.
57
+ const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
58
+ const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
59
+ const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
60
+ const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
61
+
62
+ const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;
63
+
64
+ // We reuse shmem for accumulation matrices
65
+ const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);
66
+
67
+ var<workgroup> shmem: array<f16, SHMEM_SIZE>;
68
+
69
+ @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
70
+ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
71
+ @builtin(local_invocation_id) local_id: vec3<u32>,
72
+ @builtin(subgroup_id) subgroup_id: u32) {
73
+
74
+ let thread_id = local_id.x;
75
+ let subgroup_m = subgroup_id % SUBGROUP_M;
76
+ let subgroup_n = subgroup_id / SUBGROUP_M;
77
+
78
+ let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
79
+ let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
80
+ let wg_per_matrix = wg_m_count * wg_n_count;
81
+
82
+ let batch_idx = wg_id.x / wg_per_matrix;
83
+
84
+ let wg_in_batch = wg_id.x % wg_per_matrix;
85
+ let wg_m = wg_in_batch % wg_m_count;
86
+ let wg_n = wg_in_batch / wg_m_count;
87
+
88
+ let dst2_stride = params.m * params.n;
89
+ let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
90
+
91
+ let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
92
+ let src03_idx = dst3_idx / params.broadcast3;
93
+ let src13_idx = dst3_idx;
94
+ let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
95
+ let src02_idx = dst2_idx / params.broadcast2;
96
+ let src12_idx = dst2_idx;
97
+
98
+ let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
99
+ let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
100
+
101
+ let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
102
+ let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
103
+
104
+ var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;
105
+
106
+ for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
107
+
108
+ // see mul_mat_decls.tmpl
109
+ init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
110
+ init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
111
+
112
+ workgroupBarrier();
113
+
114
+ if (subgroup_id < EXPECTED_SUBGROUPS) {
115
+
116
+ for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {
117
+
118
+ let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
119
+ var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
120
+ for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
121
+ src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
122
+ &shmem,
123
+ src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
124
+ false,
125
+ TILE_K
126
+ );
127
+ }
128
+
129
+ let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
130
+ for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
131
+ let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
132
+ &shmem,
133
+ src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
134
+ true,
135
+ TILE_K
136
+ );
137
+ for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
138
+ acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
139
+ }
140
+ }
141
+ }
142
+ }
143
+
144
+ workgroupBarrier();
145
+ }
146
+
147
+ let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
148
+
149
+ // Stage the subgroup matrix tiles into shared memory
150
+ // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).
151
+ let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;
152
+ let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
153
+ let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
154
+
155
+ if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
156
+ for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
157
+ for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
158
+ let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
159
+ let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
160
+ let out_base = local_row * WG_TILE_STRIDE + local_col;
161
+ subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
162
+ }
163
+ }
164
+ }
165
+
166
+ workgroupBarrier();
167
+
168
+ // Cooperative write: iterate over the entire workgroup tile
169
+ let tile_rows = WG_N_SG_TILE_SIZE;
170
+ let tile_cols = WG_M_SG_TILE_SIZE;
171
+ let total_tile_elems = tile_rows * tile_cols;
172
+ let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
173
+ let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
174
+
175
+ for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
176
+ let local_row = idx % WG_TILE_STRIDE;
177
+ let local_col = idx / WG_TILE_STRIDE;
178
+
179
+ let global_row = tile_dst_row_base + local_row;
180
+ let global_col = tile_dst_col_base + local_col;
181
+
182
+ if (global_col < params.n && global_row < params.m) {
183
+ let dst_idx = dst_batch_offset + global_col * params.m + global_row;
184
+ store_dst(idx, dst_idx/VEC_SIZE);
185
+ }
186
+ }
187
+ }
188
+
@@ -0,0 +1,194 @@
1
+
2
+ enable f16;
3
+
4
+ #include "common_decls.tmpl"
5
+
6
+ #ifdef VEC
7
+
8
+ #define VEC_SIZE 4
9
+ #define DST_TYPE vec4<f32>
10
+ #define SRC0_TYPE vec4<SRC0_INNER_TYPE>
11
+ #define SRC1_TYPE vec4<SRC1_INNER_TYPE>
12
+
13
+ fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
14
+ return f32(dot(SRC1_TYPE(src0_val), src1_val));
15
+ }
16
+
17
+ fn store_val(group_base: u32) -> vec4<f32> {
18
+ return vec4<f32>(partial_sums[group_base],
19
+ partial_sums[group_base + THREADS_PER_OUTPUT],
20
+ partial_sums[group_base + THREADS_PER_OUTPUT * 2],
21
+ partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
22
+ }
23
+ #endif
24
+
25
+ #ifdef SCALAR
26
+
27
+ #define VEC_SIZE 1
28
+ #define DST_TYPE f32
29
+ #define SRC0_TYPE SRC0_INNER_TYPE
30
+ #define SRC1_TYPE SRC1_INNER_TYPE
31
+
32
+ fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
33
+ return f32(src0_val) * f32(src1_val);
34
+ }
35
+
36
+ fn store_val(group_base: u32) -> f32 {
37
+ return partial_sums[group_base];
38
+ }
39
+ #endif
40
+
41
+ #ifdef MUL_ACC_FLOAT
42
+ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
43
+ var local_sum = 0.0;
44
+ for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) {
45
+ let a = src0[(idx_base + k_outer + i) / VEC_SIZE];
46
+ let b = shared_vector[i / VEC_SIZE];
47
+ local_sum += inner_dot(a, b);
48
+ }
49
+ return local_sum;
50
+ }
51
+ #endif
52
+
53
+ #ifdef MUL_ACC_Q4_0
54
+
55
+ const BLOCK_SIZE = 32;
56
+ const NQ = 16u; // number of weights per thread
57
+ const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
58
+ const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
59
+ const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
60
+
61
+ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
62
+ var local_sum = 0.0;
63
+ for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
64
+ let blck_idx = i / BLOCK_SIZE;
65
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
66
+ let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
67
+ // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
68
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
69
+ let d = f32(src0[scale_idx]);
70
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
71
+ let q_0 = src0[scale_idx + 1 + block_offset + j];
72
+ let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
73
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
74
+ for (var k: u32 = 0; k < 4; k++) {
75
+ let q_byte = get_byte(q_packed, k);
76
+ let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
77
+ let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
78
+ local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
79
+ local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
80
+ }
81
+ }
82
+ }
83
+ return local_sum;
84
+ }
85
+ #endif
86
+
87
+ struct MulMatParams {
88
+ offset_src0: u32,
89
+ offset_src1: u32,
90
+ offset_dst: u32,
91
+ m: u32,
92
+ n: u32,
93
+ k: u32,
94
+ stride_01: u32,
95
+ stride_11: u32,
96
+ stride_02: u32,
97
+ stride_12: u32,
98
+ stride_03: u32,
99
+ stride_13: u32,
100
+ bs02: u32,
101
+ bs03: u32,
102
+ broadcast2: u32,
103
+ broadcast3: u32
104
+ };
105
+
106
+ // SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
107
+ @group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
108
+ @group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
109
+ @group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
110
+
111
+ @group(0) @binding(3) var<uniform> params: MulMatParams;
112
+
113
+ const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;
114
+
115
+ // Shared memory for collaborative loading and reduction
116
+ var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>; // Cache vector tile
117
+ var<workgroup> partial_sums: array<f32, WG_SIZE>; // For reduction
118
+
119
+ @compute @workgroup_size(WG_SIZE)
120
+ fn main(
121
+ @builtin(local_invocation_id) local_id: vec3<u32>,
122
+ @builtin(workgroup_id) wg_id: vec3<u32>,
123
+ @builtin(num_workgroups) num_wg: vec3<u32>) {
124
+ let thread_id = local_id.x;
125
+
126
+ // Handle batch dimensions
127
+ let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
128
+ let wg_linear = wg_id.y * num_wg.x + wg_id.x;
129
+ let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
130
+ let batch_idx = wg_linear / output_groups;
131
+ if (batch_idx >= total_batches) {
132
+ return;
133
+ }
134
+
135
+ // Which of the outputs does this thread belong to?
136
+ let thread_group = thread_id / THREADS_PER_OUTPUT;
137
+ let thread_in_group = thread_id % THREADS_PER_OUTPUT;
138
+
139
+ // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
140
+ let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
141
+
142
+ let dst2_stride = params.m * params.n;
143
+ let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
144
+ let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
145
+ let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
146
+ let src03_idx = dst3_idx / params.broadcast3;
147
+ let src13_idx = dst3_idx;
148
+ let src02_idx = dst2_idx / params.broadcast2;
149
+ let src12_idx = dst2_idx;
150
+
151
+ let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
152
+ let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
153
+ let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
154
+
155
+ var local_sum = 0.0;
156
+
157
+ // Each thread processes multiple K elements and accumulates
158
+ for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
159
+ let tile_size = min(TILE_K, params.k - k_tile);
160
+
161
+ // Cooperatively load vector tile into shared memory (all threads)
162
+ for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
163
+ shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
164
+ }
165
+
166
+ workgroupBarrier();
167
+
168
+ if (output_row < params.m) {
169
+ local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
170
+ }
171
+
172
+ workgroupBarrier();
173
+ }
174
+
175
+ // Store partial sums and reduce within each partition
176
+ partial_sums[thread_id] = local_sum;
177
+ workgroupBarrier();
178
+ let group_base = thread_group * THREADS_PER_OUTPUT;
179
+ let thread_base = group_base + thread_in_group;
180
+ var offset: u32 = THREADS_PER_OUTPUT / 2;
181
+ while (offset > 0) {
182
+ if (thread_in_group < offset) {
183
+ partial_sums[thread_base] += partial_sums[thread_base + offset];
184
+ }
185
+ offset = offset / 2;
186
+ workgroupBarrier();
187
+ }
188
+
189
+ // Store back to global memory
190
+ if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) {
191
+ dst[dst_idx / VEC_SIZE] = store_val(group_base);
192
+ }
193
+ }
194
+